# In [10]:
import pandas as pd
import numpy as np
import json
import math
import networkx as nx
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import Point
import matplotlib.animation as animation
import ast
import warnings
from datetime import datetime, timezone
from collections import defaultdict
import time

warnings.filterwarnings('ignore')

##################################################################################################
# You need to install pulp: pip install pulp
##################################################################################################

import pulp

##################################################################################################
# GLOBAL LOG FOR SHUTTLE OPERATIONS
##################################################################################################



# In [None]:

shuttle_operations_log = []  # module-level list of state snapshots, appended to by log_shuttle_operation


def log_shuttle_operation(shuttle, current_time, action_description):
    """
    Append one snapshot of ``shuttle``'s state to ``shuttle_operations_log``.

    Parameters
    ----------
    shuttle : object exposing ``shuttle_id``, ``location_node``,
        ``battery_level``, ``status`` and ``assigned_passengers``.
    current_time : simulation time, stored under the ``'time'`` key.
    action_description : free-text label, stored under the ``'action'`` key.
    """
    # The passenger list is copied so that later boarding/alighting on the
    # shuttle does not retroactively rewrite this entry.
    shuttle_operations_log.append(
        dict(
            time=current_time,
            shuttle_id=shuttle.shuttle_id,
            location=shuttle.location_node,  # (lat, lon)
            battery_level=shuttle.battery_level,
            status=shuttle.status,
            onboard_passengers=list(shuttle.assigned_passengers),
            action=action_description,
        )
    )


##################################################################################################
# SHUTTLE CLASS
##################################################################################################

class Shuttle:
    """
    An on-demand electric shuttle that serves a single suburb.

    Several passengers may ride at once, up to ``capacity``. ``status`` moves
    between 'idle', 'waiting_for_passengers' (after boarding) and 'charging'.
    ``charge(t)`` adds ``charging_rate * t`` to the battery, capped at
    ``battery_capacity``.
    """

    def __init__(
        self,
        shuttle_id: int,
        suburb_id: int,
        depot_node: tuple,
        capacity: int = 6,
        battery_capacity: float = 100.0,
        initial_battery_level: float = 100.0,
        speed: float = 30.0,
        charging_rate: float = 10.0
    ):
        # Identity and home base.
        self.shuttle_id = shuttle_id
        self.suburb_id = suburb_id
        self.depot_node = depot_node
        self.capacity = capacity

        # Energy and motion parameters.
        self.battery_capacity = battery_capacity
        self.battery_level = initial_battery_level
        self.speed = speed
        self.charging_rate = charging_rate

        # Mutable operating state: parked, empty and idle at the depot.
        self.location_node = depot_node
        self.assigned_passengers = []
        self.status = 'idle'
        self.route = []

    def has_capacity_for(self, n_passengers: int) -> bool:
        """Return True if ``n_passengers`` more riders would still fit."""
        free_seats = self.capacity - len(self.assigned_passengers)
        return n_passengers <= free_seats

    def assign_passengers(self, passenger_ids: list):
        """
        Board ``passenger_ids`` and mark the shuttle 'waiting_for_passengers'.

        Raises ValueError (and boards nobody) if they would exceed capacity.
        """
        if not self.has_capacity_for(len(passenger_ids)):
            raise ValueError("Shuttle cannot accommodate these passengers due to capacity.")
        self.assigned_passengers.extend(passenger_ids)
        self.status = 'waiting_for_passengers'

    def remove_passengers(self, passenger_ids: list):
        """
        Drop each listed passenger who is aboard (unknown ids are ignored).

        The shuttle becomes 'idle' once nobody is left aboard.
        """
        for pid in passenger_ids:
            try:
                self.assigned_passengers.remove(pid)
            except ValueError:
                pass  # passenger was not aboard
        if not self.assigned_passengers:
            self.status = 'idle'

    def update_location(self, new_node: tuple):
        """Move the shuttle to ``new_node``."""
        self.location_node = new_node

    def charge(self, time_elapsed: float):
        """
        Charge for ``time_elapsed`` at ``charging_rate``.

        Status is 'charging' while below capacity; once full, the level is
        capped at ``battery_capacity`` and status returns to 'idle'.
        """
        self.status = 'charging'
        self.battery_level += self.charging_rate * time_elapsed
        if self.battery_level >= self.battery_capacity:
            self.battery_level = self.battery_capacity
            self.status = 'idle'

    def needs_charging(self, threshold: float = 20.0) -> bool:
        """Return True when the battery level is strictly below ``threshold``."""
        return threshold > self.battery_level

    def __repr__(self):
        parts = (
            f"id={self.shuttle_id}",
            f"suburb={self.suburb_id}",
            f"loc={self.location_node}",
            f"passengers={len(self.assigned_passengers)}",
            f"battery={self.battery_level}",
            f"status={self.status}",
        )
        return "Shuttle(" + ", ".join(parts) + ")"


##################################################################################################
# PASSENGER & LEG CREATION
##################################################################################################

def parse_coord_str(s: str) -> tuple:
    """
    Parse a coordinate such as ``"(-33.87, 151.21)"`` into a ``(lat, lon)`` tuple.

    ``ast.literal_eval`` accepts only Python literals, so no code is executed.
    Values that are already a tuple or list (e.g. a column parsed twice) are
    passed through instead of crashing ``literal_eval``. A list result such as
    ``"[lat, lon]"`` is converted to a tuple so it is hashable and compares
    equal to the tuple nodes of the road graph.

    Raises ValueError / SyntaxError if ``s`` is a string that is not a valid
    Python literal.
    """
    value = s if isinstance(s, (tuple, list)) else ast.literal_eval(s)
    if isinstance(value, list):
        value = tuple(value)
    return value

def create_passenger_trip_legs(passenger_row: pd.Series) -> list:
    """
    Split one passenger trip into its three sequential legs.

    Legs, in order:
      1) shuttle: real_origin_node      -> origin_stop_node       (first mile)
      2) bus:     origin_stop_node      -> destination_stop_node
      3) shuttle: destination_stop_node -> real_destination_node  (last mile)

    Only leg 1 has a known start time: the row's ``unix_timestamp``, i.e. the
    earliest time the passenger is ready for pickup. Legs 2 and 3 start at
    ``float('inf')`` as a placeholder, to be updated once the simulation knows
    when the passenger reaches each stop. Shuttle legs carry the
    ``suburb_id`` used to match them with shuttles; the bus leg has none.

    Returns
    -------
    list of three leg dicts.
    """
    columns = (
        'passenger_id', 'unix_timestamp',
        'origin_stop_node', 'destination_stop_node',
        'real_origin_node', 'real_destination_node',
        'origin_suburb', 'destination_suburb',
    )
    (pid, ready_at, board_stop, alight_stop,
     home_node, goal_node, home_suburb, goal_suburb) = (passenger_row[c] for c in columns)
    not_known_yet = float('inf')

    first_mile = {
        'passenger_id': pid,
        'mode': 'shuttle',
        'start_node': home_node,
        'end_node': board_stop,
        'start_time': ready_at,
        'status': 'waiting_for_shuttle',
        'suburb_id': home_suburb,
    }
    bus_ride = {
        'passenger_id': pid,
        'mode': 'bus',
        'start_node': board_stop,
        'end_node': alight_stop,
        'start_time': not_known_yet,  # set when the passenger actually boards
        'status': 'waiting_for_bus',
    }
    last_mile = {
        'passenger_id': pid,
        'mode': 'shuttle',
        'start_node': alight_stop,
        'end_node': goal_node,
        'start_time': not_known_yet,  # set when the bus arrives
        'status': 'waiting_for_shuttle',
        'suburb_id': goal_suburb,
    }
    return [first_mile, bus_ride, last_mile]


##################################################################################################
# BUS TIME-EXPANDED GRAPH
##################################################################################################

def build_time_expanded_graph(bus_data: pd.DataFrame):
    """
    Build a time-expanded directed graph of the bus network.

    Nodes are ``(stop, time)`` pairs, one for every arrival or departure time
    seen at that stop. Edges are weighted by the elapsed time between their
    endpoints:
      * ride edges:    (stop_i, departure_i) -> (stop_{i+1}, arrival_{i+1})
      * waiting edges: (stop, t_k) -> (stop, t_{k+1}) for consecutive times

    ``bus_data`` must have the columns 'stops', 'arrival_times' and
    'departure_times', each holding one index-aligned sequence per trip.

    Returns
    -------
    (G, stop_times_map)
        G : nx.DiGraph as described above.
        stop_times_map : dict stop -> sorted list of its event times.
    """
    trips = [
        (row['stops'], row['arrival_times'], row['departure_times'])
        for _, row in bus_data.iterrows()
    ]

    # Collect every event time (arrival or departure) per stop.
    event_times = {}
    for stops, arrivals, departures in trips:
        for i, stop in enumerate(stops):
            event_times.setdefault(stop, set()).update((arrivals[i], departures[i]))
    stop_times_map = {stop: sorted(times) for stop, times in event_times.items()}

    G = nx.DiGraph()
    for stop, times in stop_times_map.items():
        for t in times:
            G.add_node((stop, t))

    # Ride edges: riding the same trip from one stop to the next.
    for stops, arrivals, departures in trips:
        for i in range(len(stops) - 1):
            src = (stops[i], departures[i])
            dst = (stops[i + 1], arrivals[i + 1])
            if src in G and dst in G:
                G.add_edge(src, dst, weight=dst[1] - src[1])

    # Waiting edges: staying at a stop until its next event time.
    for stop, times in stop_times_map.items():
        for earlier, later in zip(times, times[1:]):
            G.add_edge((stop, earlier), (stop, later), weight=later - earlier)

    return G, stop_times_map


def earliest_arrival_time(
    G: nx.DiGraph,
    stop_times_map: dict,
    origin_stop: tuple,
    origin_time: float,
    dest_stop: tuple
):
    """
    Return (board_time, arrival_time) if passenger arrives at `origin_stop` >= `origin_time`.
    The board_time is the actual departure time for the next bus >= origin_time.
    The arrival_time is the earliest arrival time at `dest_stop`.
    Return None if unreachable.
    """
    times_for_origin = stop_times_map.get(origin_stop, [])
    start_node = None
    # Find the earliest bus departure time >= origin_time
    for t in times_for_origin:
        if t >= origin_time:
            start_node = (origin_stop, t)
            break
    if start_node is None:
        return None

    dist_result = nx.single_source_dijkstra_path_length(G, source=start_node, weight='weight')
    best_arrival_time = None
    board_time = start_node[1]

    # Evaluate arrival times at dest_stop
    for (stop, t) in dist_result.keys():
        if stop == dest_stop:
            # total travel time from 'board_time' perspective
            arrival_clock = board_time + dist_result[(stop, t)]
            if (best_arrival_time is None) or (arrival_clock < best_arrival_time):
                best_arrival_time = arrival_clock

    if best_arrival_time is None:
        return None

    return (board_time, best_arrival_time)


##################################################################################################
# SHUTTLE UTILITIES
##################################################################################################

def estimate_shuttle_travel_time(G_road: nx.Graph, start_node: tuple, end_node: tuple) -> float:
    """
    Naive approach: each edge => 2 minutes. 
    """
    if start_node == end_node:
        return 0.0
    path = nx.shortest_path(G_road, source=start_node, target=end_node)
    edges_count = len(path) - 1
    return edges_count * 2.0

def reposition_shuttle(shuttle, target_node, departure_time, G_road) -> float:
    """
    Drive the shuttle from its current node to ``target_node``.

    Logs a start and an arrival entry, drains the battery by half the travel
    time (naive model, floored at 0) and moves the shuttle. If it is already
    at ``target_node``, nothing happens and nothing is logged.

    Returns the arrival time at ``target_node``.
    """
    origin = shuttle.location_node
    if origin == target_node:
        return departure_time

    log_shuttle_operation(shuttle, departure_time,
        f"Start reposition from {origin} to {target_node}"
    )
    drive_time = estimate_shuttle_travel_time(G_road, origin, target_node)
    arrival = departure_time + drive_time

    shuttle.battery_level = max(shuttle.battery_level - drive_time / 2.0, 0)
    shuttle.update_location(target_node)

    log_shuttle_operation(shuttle, arrival,
        f"Arrived at {target_node} (reposition complete)"
    )
    return arrival

def shuttle_auto_return_if_needed(shuttle, current_time, G_road) -> float:
    """
    Send an empty, low-battery shuttle back to its depot to recharge.

    Applies only when nobody is aboard and ``battery_level`` < 20.0 (20% of
    the default 100.0 capacity). In that case the shuttle repositions to
    ``depot_node`` and charges there.

    Returns the time when the shuttle is done (after charging), or
    ``current_time`` unchanged if no return was needed.
    """
    # Occupancy is checked first: the battery is not inspected while
    # passengers are aboard.
    if shuttle.assigned_passengers or not shuttle.needs_charging(20.0):
        return current_time
    at_depot = reposition_shuttle(shuttle, shuttle.depot_node, current_time, G_road)
    return shuttle_charge_if_needed(shuttle, at_depot)

def shuttle_charge_if_needed(shuttle, current_time: float) -> float:
    """
    Charge the shuttle to full if its battery is below 20.0.

    Charging happens wherever the shuttle currently is; no travel is
    simulated here. Charge time = missing energy / ``shuttle.charging_rate``.

    Returns
    -------
    float
        Time at which charging completes, or ``current_time`` unchanged when
        no charge was needed.
    """
    if not shuttle.needs_charging(20.0):
        return current_time

    log_shuttle_operation(shuttle, current_time, "Battery below threshold, begin charging")
    missing_energy = shuttle.battery_capacity - shuttle.battery_level
    charge_time = missing_energy / shuttle.charging_rate
    finish_time = current_time + charge_time

    shuttle.charge(charge_time)
    # rate * (missing / rate) can land a hair below capacity in floating
    # point. That would leave the shuttle stuck in 'charging' status even
    # though it is logged as full, so snap it to exactly full and idle.
    if shuttle.battery_level < shuttle.battery_capacity:
        shuttle.battery_level = shuttle.battery_capacity
        shuttle.status = 'idle'
    log_shuttle_operation(shuttle, finish_time, "Finished charging to full")
    return finish_time

def shuttle_multi_pickup_and_dropoff(
    shuttle,
    passengers,
    G_road,
    departure_time: float,
    leg_type=None,
    simulation_results=None
) -> float:
    """
    Serve a batch of shuttle legs that share one drop-off node.

    Steps:
      1) Visit each passenger's 'start_node' in list order and board them.
      2) Drive everyone to the shared drop-off: the 'end_node' of the first
         leg. Callers group legs by end_node.
      3) Everyone alights.
      4) If the shuttle is now empty and below 20.0 battery, it returns to its
         depot and charges. This only affects the shuttle.

    Parameters
    ----------
    shuttle : Shuttle serving the batch.
    passengers : list of leg dicts with 'passenger_id', 'start_node', 'end_node'.
    G_road : road graph used for travel-time estimates.
    departure_time : time the shuttle starts the batch.
    leg_type : "first_mile" or "last_mile"; selects which fields of
        ``simulation_results`` are filled. Other values (or None) fill nothing.
    simulation_results : optional dict passenger_id -> result record, updated
        in place with shuttle id, pickup time and drop-off time.

    Returns
    -------
    float
        Time at which the passengers are dropped off. Any later depot return
        or charging is deliberately excluded, because callers use this value
        as the passengers' leg completion time.

    Raises
    ------
    ValueError
        From ``shuttle.assign_passengers`` if boarding exceeds capacity.
    """
    if not passengers:
        return departure_time

    drop_off_node = passengers[0]['end_node']

    # (shuttle-id key, leg-start key, leg-end key) in simulation_results.
    result_keys = None
    if simulation_results is not None:
        result_keys = {
            "first_mile": ('first_shuttle_id', 'first_leg_start_time', 'first_leg_end_time'),
            "last_mile": ('last_shuttle_id', 'last_leg_start_time', 'last_leg_end_time'),
        }.get(leg_type)

    current_time = departure_time
    for p in passengers:
        pid = p['passenger_id']
        origin_node = p['start_node']

        # Drive to the passenger's origin; arrival there is the pickup time.
        pickup_time = reposition_shuttle(shuttle, origin_node, current_time, G_road)
        shuttle.assign_passengers([pid])
        log_shuttle_operation(
            shuttle,
            pickup_time,
            f"Passenger {pid} boards at {origin_node}"
        )

        if result_keys is not None:
            id_key, start_key, _ = result_keys
            simulation_results[pid][id_key] = shuttle.shuttle_id
            simulation_results[pid][start_key] = pickup_time

        current_time = pickup_time

    # Drive everyone to the shared drop-off. Battery drains by half the
    # travel time (naive model), floored at 0.
    travel_time = estimate_shuttle_travel_time(G_road, shuttle.location_node, drop_off_node)
    arrival_time = current_time + travel_time
    shuttle.battery_level = max(shuttle.battery_level - travel_time / 2.0, 0)
    shuttle.update_location(drop_off_node)

    log_shuttle_operation(
        shuttle,
        arrival_time,
        f"Arrived at drop-off {drop_off_node} with passengers {shuttle.assigned_passengers}"
    )

    departing_pids = [p['passenger_id'] for p in passengers]
    shuttle.remove_passengers(departing_pids)
    log_shuttle_operation(
        shuttle,
        arrival_time,
        f"Passengers {departing_pids} alight at {drop_off_node}"
    )

    if result_keys is not None:
        end_key = result_keys[2]
        for pid in departing_pids:
            simulation_results[pid][end_key] = arrival_time

    # Possibly send the now-empty shuttle home to charge. The time this
    # returns is when the *shuttle* is free again, not when the passengers
    # arrived, so it is not the return value.
    shuttle_auto_return_if_needed(shuttle, arrival_time, G_road)
    return arrival_time


##################################################################################################
# MILP ASSIGNMENT
##################################################################################################

def solve_shuttle_assignment_milp(shuttles, horizon_requests):
    """
    Assign each request to at most one same-suburb shuttle with a MILP.

    Model
    -----
    A binary x[k, r] exists for each shuttle k whose ``suburb_id`` equals
    request r's 'suburb_id'.
        maximise  sum x[k, r]                     (requests served)
        s.t.      sum_k x[k, r] <= 1              (each request at most once)
                  sum_r n_r * x[k, r] <= cap_k    (n_r = r.get('num_passengers', 1))

    If no request has a candidate shuttle, the solver is not invoked at all.

    Returns
    -------
    assignment : dict {(shuttle_id, passenger_id): 1 or 0}
    unassigned_reasons : dict {passenger_id: reason_string}
    """
    # Candidate shuttles per suburb, built once instead of scanned per request.
    shuttles_by_suburb = defaultdict(list)
    for s in shuttles:
        shuttles_by_suburb[s.suburb_id].append(s)

    candidate_pairs = []          # (shuttle, request) pairs allowed by suburb
    feasible_passengers = set()   # PIDs with >= 1 candidate shuttle
    for req in horizon_requests:
        pid = req['passenger_id']
        matched = shuttles_by_suburb.get(req['suburb_id'], [])
        if matched:
            feasible_passengers.add(pid)
            candidate_pairs.extend((s, req) for s in matched)

    assignment = {}
    if candidate_pairs:
        # Also imported at module level; imported here so the function is
        # self-contained and the no-candidate path never needs the solver.
        import pulp

        prob = pulp.LpProblem("ShuttleAssignment", pulp.LpMinimize)

        # Decision variables. Names are index-based: PuLP rewrites characters
        # such as '-', ' ' or '/' to '_', so names built from passenger ids
        # could collide and be rejected as duplicates.
        x = {}
        for s, req in candidate_pairs:
            key = (s.shuttle_id, req['passenger_id'])
            if key not in x:
                x[key] = pulp.LpVariable(f"x_{len(x)}", cat=pulp.LpBinary)

        # Objective: maximise served requests, i.e. minimise their negative.
        prob += -1 * pulp.lpSum(x.values()), "MaximizeServed"

        # (A) each request is served by at most one shuttle
        vars_by_pid = defaultdict(list)
        for (_, pid), var in x.items():
            vars_by_pid[pid].append(var)
        for pid_vars in vars_by_pid.values():
            prob += pulp.lpSum(pid_vars) <= 1

        # (B) riders assigned to a shuttle must fit its capacity
        load_terms = defaultdict(list)   # id(shuttle) -> capacity terms
        shuttle_of = {}
        for s, req in candidate_pairs:
            num_pass = req.get('num_passengers', 1)
            load_terms[id(s)].append(num_pass * x[(s.shuttle_id, req['passenger_id'])])
            shuttle_of[id(s)] = s
        for key, terms in load_terms.items():
            prob += pulp.lpSum(terms) <= shuttle_of[key].capacity

        prob.solve(pulp.PULP_CBC_CMD(msg=0))

        for key, var in x.items():
            val = pulp.value(var)
            # value() is None when the solver left the variable unset.
            assignment[key] = 1 if val is not None and val > 0.5 else 0

    served = {pid for (_, pid), chosen in assignment.items() if chosen == 1}
    unassigned_reasons = {}
    for req in horizon_requests:
        pid = req['passenger_id']
        if pid not in feasible_passengers:
            unassigned_reasons[pid] = "No feasible shuttle in this suburb."
        elif pid not in served:
            unassigned_reasons[pid] = "Solver assigned x=0 for all feasible shuttles."

    return assignment, unassigned_reasons



##################################################################################################
# ROLLING-HORIZON
##################################################################################################

def _fast_forward_horizon(current_time, next_start, horizon_length):
    """Return the window start current_time + k*horizon_length (k >= 0) whose window contains next_start."""
    if next_start < current_time + horizon_length:
        return current_time
    windows_to_skip = (next_start - current_time) // horizon_length
    return current_time + windows_to_skip * horizon_length


def rolling_horizon_for_legs(
    shuttles,
    leg_list,
    G_road,
    horizon_length,
    simulation_results=None,
    leg_type="first_mile"
):
    """
    Serve shuttle legs with a rolling-horizon MILP.

    Time is cut into windows of ``horizon_length``, aligned to multiples of
    ``horizon_length`` from 0. Each window does three things:
      1) Releases the legs whose 'start_time' is before the window end into
         a pending pool.
      2) Assigns the pool with ``solve_shuttle_assignment_milp``.
      3) Drives every assigned group (same shuttle, same drop-off node) with
         ``shuttle_multi_pickup_and_dropoff``.
    Unassigned legs stay pending for the next window.

    The assignment model does not depend on the clock. A window that assigns
    nothing can therefore only be followed by progress once new legs are
    released. So when nothing was assigned:
      * if no legs remain to be released, the loop stops;
      * otherwise the clock jumps straight to the window containing the next
        leg's start_time, instead of stepping through empty windows one by
        one (with unix-timestamp start times that would be billions of steps).
    Requests that can never be served (e.g. no shuttle in their suburb) thus
    no longer stop the loop before later legs are released.

    Parameters
    ----------
    shuttles : list of Shuttle.
    leg_list : leg dicts with at least 'passenger_id', 'start_time',
        'end_node', 'suburb_id' (and 'start_node' for pickups). Not modified.
    G_road : road graph, forwarded to the shuttle routines.
    horizon_length : window length, in the same units as 'start_time'.
    simulation_results, leg_type : forwarded to shuttle_multi_pickup_and_dropoff.

    Returns
    -------
    dict passenger_id -> time returned by shuttle_multi_pickup_and_dropoff
    for that passenger's group. Passengers never served are absent.

    Raises
    ------
    ValueError
        If ``horizon_length`` is not positive (the clock could never advance).
    """
    if not leg_list:
        return {}
    if not horizon_length > 0:
        raise ValueError(f"horizon_length must be positive, got {horizon_length!r}")

    # Sorted copy: the caller's list is left in its original order.
    legs = sorted(leg_list, key=lambda leg: leg['start_time'])
    n_legs = len(legs)

    # On duplicate ids the first shuttle wins, as a linear scan would.
    shuttle_by_id = {}
    for s in shuttles:
        shuttle_by_id.setdefault(s.shuttle_id, s)

    finish_times = {}
    pending = []        # released but not yet assigned legs
    idx = 0             # next leg in `legs` to release
    current_time = 0.0
    horizon_index = 0

    while True:
        if not pending and idx >= n_legs:
            print("\n[DEBUG] All passengers served or no new arrivals => done.")
            break

        horizon_index += 1
        horizon_end = current_time + horizon_length
        print(f"\n[DEBUG] Starting horizon #{horizon_index} => t=[{current_time}, {horizon_end}]")

        # 1) Release legs that start before the end of this window.
        newly_added = 0
        while idx < n_legs and legs[idx]['start_time'] < horizon_end:
            pending.append(legs[idx])
            idx += 1
            newly_added += 1
        if newly_added:
            print(f"[DEBUG] Added {newly_added} newly arrived requests to unassigned_leg.")

        made_progress = False
        if not pending:
            print("[DEBUG] No unassigned passengers => advancing horizon.")
        else:
            # 2) Solve the assignment for everything pending.
            assignment, unassigned_reasons = solve_shuttle_assignment_milp(shuttles, pending)

            # 3) Group assigned legs by (shuttle, drop-off node): one trip per group.
            assigned_map = defaultdict(list)
            assigned_pids = set()
            for req in pending:
                pid = req['passenger_id']
                for s in shuttles:
                    if assignment.get((s.shuttle_id, pid), 0) == 1:
                        assigned_map[(s.shuttle_id, req['end_node'])].append(req)
                        assigned_pids.add(pid)

            still_pending = [r for r in pending if r['passenger_id'] not in assigned_pids]
            if still_pending:
                print(f"\n[DEBUG] Horizon #{horizon_index} did NOT assign:")
                for r in still_pending:
                    pid = r['passenger_id']
                    reason = unassigned_reasons.get(pid, "Solver assigned x=0 for all feasible shuttles.")
                    print(f"   PID {pid} => {reason}")

            # 4) Drive each group, starting no earlier than its earliest leg.
            for (sh_id, _drop_node), req_list in assigned_map.items():
                shuttle_obj = shuttle_by_id.get(sh_id)
                if shuttle_obj is None:
                    continue
                earliest_req_time = min(rr['start_time'] for rr in req_list)
                updated_time = shuttle_charge_if_needed(shuttle_obj, earliest_req_time)
                final_arr = shuttle_multi_pickup_and_dropoff(
                    shuttle=shuttle_obj,
                    passengers=req_list,
                    G_road=G_road,
                    departure_time=updated_time,
                    leg_type=leg_type,
                    simulation_results=simulation_results
                )
                for rr in req_list:
                    finish_times[rr['passenger_id']] = final_arr

            made_progress = bool(assigned_map)
            pending = still_pending

        current_time = horizon_end
        if made_progress:
            continue

        # 5) No progress: only newly released legs can change the outcome.
        if idx >= n_legs:
            print("\n[DEBUG] No progress in the last horizon => stopping.")
            break
        next_start = legs[idx]['start_time']
        if not math.isfinite(next_start):
            # Sorted order: every remaining leg has an infinite start time.
            print("\n[DEBUG] Remaining legs have no finite start_time => stopping.")
            break
        current_time = _fast_forward_horizon(current_time, next_start, horizon_length)

    return finish_times


def run_rolling_simulation(
    shuttles,
    passenger_legs,
    G_road,
    G_bus,
    stop_times_map_bus,
    horizon_length=30.0
):
    """
    Simulate every passenger's first-mile shuttle, bus ride and last-mile shuttle.

    Phases
    ------
    1) First mile: rolling-horizon shuttle assignment of leg 1.
    2) Bus: earliest arrival at the destination stop, boarding from the time
       the passenger was dropped at the origin stop.
    3) Last mile: rolling-horizon shuttle assignment of leg 3, released at the
       bus arrival time.

    Legs are classified by mode and start_time. A shuttle leg with a finite
    start_time is the first mile; one with ``float('inf')`` is the last mile
    (the placeholder set by create_passenger_trip_legs). NOTE: leg 3's
    start_time is overwritten in place with the bus arrival, so the same leg
    dicts must not be fed into a second run.

    A passenger succeeds only when every leg they have is served. Otherwise
    'success' is False and 'failure_reason' says which phase failed.
    Passengers without a last-mile leg finish at the bus arrival time.

    Side effects: writes simulation_results.csv and shuttle_operations.csv.
    The latter is built from the module-level shuttle_operations_log, which
    is never cleared here and therefore accumulates across runs.

    Returns
    -------
    list of dict
        One record per passenger, sorted by passenger_id
        (``DataFrame.to_dict(orient='records')``).
    """
    simulation_results = {}
    passenger_leg_map = defaultdict(lambda: {'leg1': None, 'leg2': None, 'leg3': None})

    # Classify each passenger's legs and create an empty result record.
    for leg in passenger_legs:
        pid = leg['passenger_id']
        if leg['mode'] == 'shuttle' and leg['start_time'] != float('inf'):
            passenger_leg_map[pid]['leg1'] = leg
        elif leg['mode'] == 'bus':
            passenger_leg_map[pid]['leg2'] = leg
        elif leg['mode'] == 'shuttle' and leg['start_time'] == float('inf'):
            passenger_leg_map[pid]['leg3'] = leg

        if pid not in simulation_results:
            simulation_results[pid] = {
                'passenger_id': pid,
                'success': None,
                'failure_reason': None,
                'first_shuttle_id': None,
                'first_leg_start_time': None,
                'first_leg_end_time': None,
                'bus_start_time': None,
                'bus_end_time': None,
                'last_shuttle_id': None,
                'last_leg_start_time': None,
                'last_leg_end_time': None,
                'final_arrival_time': None,
            }

    # 1) First mile
    first_mile_legs = [d['leg1'] for d in passenger_leg_map.values() if d['leg1'] is not None]
    first_mile_finish = rolling_horizon_for_legs(
        shuttles=shuttles,
        leg_list=first_mile_legs,
        G_road=G_road,
        horizon_length=horizon_length,
        simulation_results=simulation_results,
        leg_type="first_mile"
    )

    # 2) Bus
    for pid, data in passenger_leg_map.items():
        rec = simulation_results[pid]
        if data['leg1'] is None or data['leg2'] is None:
            rec['success'] = False
            rec['failure_reason'] = 'No first mile or no bus leg'
            continue
        if pid not in first_mile_finish:
            rec['success'] = False
            rec['failure_reason'] = 'First mile never served'
            continue

        # Prefer the drop-off time recorded for this passenger's first leg.
        # The rolling-horizon finish time may also include time the shuttle
        # spent returning to its depot and charging afterwards.
        first_end = rec['first_leg_end_time']
        if first_end is None:
            first_end = first_mile_finish[pid]
        if first_end is None:
            rec['success'] = False
            rec['failure_reason'] = 'First mile never completed'
            continue

        leg2 = data['leg2']
        bus_res = earliest_arrival_time(
            G=G_bus,
            stop_times_map=stop_times_map_bus,
            origin_stop=leg2['start_node'],
            origin_time=first_end,
            dest_stop=leg2['end_node']
        )
        if bus_res is None:
            rec['success'] = False
            rec['failure_reason'] = 'No bus available'
            continue

        board_time, bus_arr = bus_res
        rec['bus_start_time'] = board_time
        rec['bus_end_time'] = bus_arr

        # The last mile becomes requestable once the bus arrives.
        if data['leg3'] is not None:
            data['leg3']['start_time'] = bus_arr

    # 3) Last mile
    last_mile_legs = []
    for pid, data in passenger_leg_map.items():
        rec = simulation_results[pid]
        if data['leg3'] is None or rec['failure_reason'] is not None:
            continue
        if rec['bus_end_time'] is not None:
            last_mile_legs.append(data['leg3'])
        else:
            rec['success'] = False
            rec['failure_reason'] = 'Bus leg never completed'
    last_mile_pids = {leg['passenger_id'] for leg in last_mile_legs}

    last_mile_finish = rolling_horizon_for_legs(
        shuttles=shuttles,
        leg_list=last_mile_legs,
        G_road=G_road,
        horizon_length=horizon_length,
        simulation_results=simulation_results,
        leg_type="last_mile"
    )

    for pid in last_mile_pids:
        rec = simulation_results[pid]
        if pid not in last_mile_finish:
            # Requested but never picked up: the passenger did not reach
            # their destination, so this is not a success.
            rec['success'] = False
            rec['failure_reason'] = 'Last mile never served'
            continue
        # Same preference as for the first mile: the recorded drop-off time.
        drop_off_time = rec['last_leg_end_time']
        if drop_off_time is None:
            drop_off_time = last_mile_finish[pid]
        rec['last_leg_end_time'] = drop_off_time
        rec['final_arrival_time'] = drop_off_time
        if rec['success'] is None:
            rec['success'] = True

    # Anyone still undecided had no last-mile leg: they are done at the bus.
    for rec in simulation_results.values():
        if rec['success'] is None:
            rec['success'] = True
            rec['final_arrival_time'] = rec['bus_end_time']

    # Finalize
    columns = [
        'passenger_id', 'success', 'failure_reason',
        'first_shuttle_id', 'first_leg_start_time', 'first_leg_end_time',
        'bus_start_time', 'bus_end_time',
        'last_shuttle_id', 'last_leg_start_time', 'last_leg_end_time',
        'final_arrival_time'
    ]
    rows = [
        [simulation_results[pid].get(col) for col in columns]
        for pid in sorted(simulation_results.keys())
    ]
    df_sim = pd.DataFrame(rows, columns=columns)
    # Save, writing time columns as integers
    save_csv_safely(
        df_sim,
        "simulation_results.csv",
        time_columns=[
            'first_leg_start_time', 'first_leg_end_time',
            'bus_start_time', 'bus_end_time',
            'last_leg_start_time', 'last_leg_end_time',
            'final_arrival_time'
        ]
    )

    df_ops = pd.DataFrame(shuttle_operations_log)
    save_csv_safely(df_ops, "shuttle_operations.csv", time_columns=['time'])

    return df_sim.to_dict(orient='records')


##################################################################################################
# PASSENGER FILTERS
##################################################################################################

def filter_passengers_no_bus_available(passenger_df, G_bus, stop_times_map_bus):
    """
    Drop passengers who have no bus connection between their stops.

    For each row, ``earliest_arrival_time`` is queried from the row's
    'origin_stop' to its 'destination_stop', starting at 'unix_timestamp'.
    Rows where it returns None are removed.

    NOTE(review): this reads 'origin_stop' / 'destination_stop', while the
    trip-leg builder reads 'origin_stop_node' / 'destination_stop_node'.
    Confirm the input frame carries these column names before calling.

    Returns a new DataFrame with a reset index, and prints how many rows
    were dropped.
    """
    unreachable = []
    for idx, row in passenger_df.iterrows():
        src, dst = row['origin_stop'], row['destination_stop']
        connection = earliest_arrival_time(
            G=G_bus,
            stop_times_map=stop_times_map_bus,
            origin_stop=src,
            origin_time=row.unix_timestamp,
            dest_stop=dst
        )
        if connection is None:
            unreachable.append(idx)

    kept = passenger_df.drop(index=unreachable).reset_index(drop=True)
    print(f"Filtered out {len(unreachable)} passengers with no bus available.")
    return kept


##################################################################################################
# SAFE CSV SAVING
##################################################################################################

def save_csv_safely(df: pd.DataFrame, filename: str, max_retries=5, wait_seconds=2, time_columns=None):
    """
    Write ``df`` to CSV, retrying while the file is locked (e.g. open in Excel).

    Numeric columns named in ``time_columns`` that exist in ``df`` are written
    as integers. Missing values become 0, and fractions are truncated toward
    zero (``astype(int)``). The conversion is applied to a copy, so the
    caller's DataFrame is not modified.

    Parameters
    ----------
    df : DataFrame to save.
    filename : output path.
    max_retries : number of write attempts made on PermissionError.
    wait_seconds : pause after each failed attempt.
    time_columns : optional list of column names to convert as above.

    Returns
    -------
    bool
        True if the file was written, False if every attempt raised
        PermissionError.
    """
    to_write = df
    for col in time_columns or []:
        if col in to_write.columns and pd.api.types.is_numeric_dtype(to_write[col]):
            if to_write is df:
                to_write = df.copy()  # never mutate the caller's frame
            to_write[col] = to_write[col].fillna(0).astype(int)

    for attempt in range(max_retries):
        try:
            to_write.to_csv(filename, index=False)
        except PermissionError:
            print(f"[Attempt {attempt+1}] Permission denied for '{filename}'. Retrying in {wait_seconds}s...")
            time.sleep(wait_seconds)
        else:
            print(f"Saved {filename}")
            return True
    print(f"Could not save {filename} after {max_retries} attempts.")
    return False


##################################################################################################
# MAIN
##################################################################################################

def main(date="2024-05-01", data_dir="../02result", horizon_length=1.0):
    """
    Run the first-mile / bus / last-mile shuttle simulation for one day.

    Steps:
      1) Load road nodes/edges and suburb polygons. Build the undirected road
         graph G_road (nodes are (lat, lon) tuples). Create one shuttle per
         suburb, with its depot at the road node closest to the suburb
         centroid. Suburbs without any road node get no shuttle.
      2) Load aggregated bus trips and build the time-expanded bus graph.
      3) Load the day's passenger demand and parse its coordinate columns.
      4) Split each passenger trip into shuttle / bus / shuttle legs.
      5) Run the rolling-horizon simulation, which writes
         simulation_results.csv and shuttle_operations.csv, then print each
         passenger's outcome.

    No passenger filtering is performed here
    (``filter_passengers_no_bus_available`` is not called).

    NOTE(review): shuttle travel times are estimated in minutes while
    passenger start times come from a 'unix_timestamp' column. Confirm the
    time units of the inputs line up.

    Parameters
    ----------
    date : selects ``passenger_demand/filtered_data_{date}.csv``.
    data_dir : root directory of the input data (default: the original
        ``../02result``).
    horizon_length : window length for the rolling horizon.
    """
    # 1) Road network and suburbs
    road_dir = f"{data_dir}/road_network"
    # JSON is UTF-8 by specification; do not rely on the platform encoding.
    with open(f"{road_dir}/node.json", "r", encoding="utf-8") as node_file:
        nodes = json.load(node_file)
    with open(f"{road_dir}/edge.json", "r", encoding="utf-8") as edge_file:
        edges = json.load(edge_file)

    gdf = gpd.read_file(f"{road_dir}/suburb.json")
    gdf = gdf[['ID', 'DIVISION_N', 'geometry']]
    suburbs = gdf.ID.to_list()  # This is our set of valid suburbs

    # Road nodes as WGS84 points (x=lon, y=lat) for the spatial join.
    df_nodes = pd.DataFrame(nodes, columns=['lat', 'lon'])
    gdf_nodes = gpd.GeoDataFrame(
        df_nodes,
        geometry=[Point(lon, lat) for lat, lon in nodes],
        crs="EPSG:4326"
    )

    G_road = nx.Graph()
    for (lat, lon) in nodes:
        G_road.add_node((lat, lon))
    for (lat1, lon1), (lat2, lon2) in edges:
        G_road.add_edge((lat1, lon1), (lat2, lon2))

    # Map each road node to the suburb polygon containing it.
    # NOTE(review): sjoin assumes the suburb layer is also EPSG:4326 - confirm.
    nodes_in_suburb = gpd.sjoin(gdf_nodes, gdf[['ID', 'geometry']], how='left', predicate='within')
    suburb_to_nodes = {
        s_id: nodes_in_suburb[nodes_in_suburb['ID'] == s_id] for s_id in suburbs
    }

    # Depot = road node inside the suburb closest to the suburb centroid.
    suburb_depots = {}
    for _, row in gdf.iterrows():
        s_id = row['ID']
        sub_nodes = suburb_to_nodes[s_id]
        if sub_nodes.empty:
            suburb_depots[s_id] = None
            continue
        # Planar distance in CRS units (degrees for EPSG:4326): a rough but
        # adequate nearest-node pick. Kept as a standalone Series rather than
        # a new column on the filtered slice (SettingWithCopyWarning).
        distances = sub_nodes.geometry.distance(row.geometry.centroid)
        closest_node = sub_nodes.loc[distances.idxmin()]
        suburb_depots[s_id] = (closest_node.geometry.y, closest_node.geometry.x)

    # One shuttle per suburb that has a depot.
    shuttles = []
    for i, s_id in enumerate(suburbs):
        if suburb_depots[s_id] is None:
            continue  # no road node inside this suburb => no depot
        shuttles.append(Shuttle(
            shuttle_id=i,
            suburb_id=s_id,
            depot_node=suburb_depots[s_id],
            capacity=60
        ))

    # 2) Time-expanded bus graph
    bus_data = pd.read_csv(f"{data_dir}/bus_data/agg_bus_data.csv")
    # 'Unnamed: 0' is typically a saved index column; tolerate its absence.
    bus_data = bus_data.drop(columns=['Unnamed: 0'], errors='ignore')
    for col in ('stops', 'arrival_times', 'departure_times'):
        bus_data[col] = bus_data[col].apply(ast.literal_eval)
    G_bus, stop_times_map_bus = build_time_expanded_graph(bus_data)

    # 3) Passenger demand
    passenger_data = pd.read_csv(f"{data_dir}/passenger_demand/filtered_data_{date}.csv")
    for col in ("real_origin_node", "origin_stop_node", "real_destination_node", "destination_stop_node"):
        passenger_data[col] = passenger_data[col].apply(parse_coord_str)

    # 4) Trip legs
    all_legs = []
    for _, row in passenger_data.iterrows():
        all_legs.extend(create_passenger_trip_legs(row))
    all_legs.sort(key=lambda leg: leg.get('start_time', float('inf')))

    # 5) Rolling simulation
    results = run_rolling_simulation(
        shuttles=shuttles,
        passenger_legs=all_legs,
        G_road=G_road,
        G_bus=G_bus,
        stop_times_map_bus=stop_times_map_bus,
        horizon_length=horizon_length
    )

    for r in results:
        pid = r['passenger_id']
        if not r['success']:
            print(f"Passenger {pid} => FAILED: {r['failure_reason']}")
        else:
            print(f"Passenger {pid} => SUCCESS, final arrival={r['final_arrival_time']}")

if __name__ == "__main__":
    main()
