In [None]:
%%pyspark
%pip install ortools

In [None]:
from ortools.constraint_solver import routing_enums_pb2, pywrapcp
from pyspark.sql import functions as F, types as T
from pyspark.sql import Row
from concurrent.futures import ThreadPoolExecutor
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
from datetime import datetime, timedelta
from pyspark.sql.functions import date_format
import math

In [None]:
def minutes_to_datetime(week_start_date, minutes):
    """Return the datetime that lies `minutes` minutes after the week start.

    week_start_date: date string in "%Y-%m-%d" format (interpreted as midnight).
    minutes: offset from that midnight, in minutes.
    """
    week_origin = datetime.strptime(week_start_date, "%Y-%m-%d")
    offset = timedelta(minutes=minutes)
    return week_origin + offset


def to_minutes(tstr):
    """Convert an "HH:MM" string into minutes since midnight."""
    hours, mins = (int(part) for part in tstr.split(':'))
    return 60 * hours + mins

In [None]:
def build_week_input(forecast_sdf, site_sdf, travel_sdf, week_start_date, spark):
    """
    Prepare one week's routing input for OR-Tools as driver-local objects.

    Parameters
    ----------
    forecast_sdf : Spark DataFrame with `week_start`, `site_id` and
        `forecast_units` columns.
    site_sdf : Spark DataFrame of site specs keyed by `site_id`.
    travel_sdf : Spark DataFrame with `from_site`, `to_site` and
        `travel_minutes` columns.
    week_start_date : value compared for equality with `week_start`.
    spark : SparkSession, used only to build the fallback depot row.

    Returns
    -------
    (df_nodes_local, travel_dict)
        df_nodes_local is a list of Rows, one per site that has demand in the
        week (site specs joined with the summed `forecast_units`).
        travel_dict maps (from_site, to_site) -> travel minutes as float.
        Returns (None, None) when the week has no forecast rows.

    Note: the weekly node list is small, but the *entire* travel table is
    collected to the driver on every call.
    """
    depot_id = "PORT0"

    # Filter forecast for the week
    week_sdf = forecast_sdf.filter(F.col("week_start") == week_start_date)
    # take(1) stops at the first row and avoids the DataFrame -> RDD conversion.
    if not week_sdf.take(1):
        print(f"No forecast rows for week {week_start_date}")
        return None, None

    # Aggregate demand per site
    demand_sdf = (
        week_sdf.groupBy("site_id")
        .agg(F.sum("forecast_units").alias("forecast_units"))
    )

    # Ensure depot exists (existence check only, no full count needed)
    if not site_sdf.filter(F.col("site_id") == depot_id).take(1):
        depot_row = [(depot_id, 0.0, 0.0, "00:00", "23:59", 0, 0)]
        depot_sdf = spark.createDataFrame(
            depot_row,
            ["site_id","lat","lon","open_time","close_time","service_time_minutes","max_visit_volume_units"]
        )
        site_sdf = depot_sdf.unionByName(site_sdf)

    # Join site specs with demand. The right join keeps every site that has
    # demand, even when it has no spec row; its spec columns are then null.
    df_nodes_sdf = site_sdf.join(demand_sdf, on="site_id", how="right").fillna(0)

    # fillna(0) only touches numeric columns, so the string time-window columns
    # would stay null and break the later "HH:MM" parsing. Default them to the
    # same all-day window the fallback depot row uses.
    window_defaults = {
        name: value
        for name, value in (("open_time", "00:00"), ("close_time", "23:59"))
        if name in df_nodes_sdf.columns
    }
    if window_defaults:
        df_nodes_sdf = df_nodes_sdf.fillna(window_defaults)

    # Convert to a list of Rows for OR-Tools (small data)
    df_nodes_local = df_nodes_sdf.collect()

    # Travel times as a local dict
    travel_dict = {
        (row["from_site"], row["to_site"]): float(row["travel_minutes"])
        for row in travel_sdf.collect()
    }

    return df_nodes_local, travel_dict


In [None]:


def create_data_model(df_nodes_local, travel_dict, barge_sdf):
    """
    Convert local weekly data into the OR-Tools input dictionary.

    Parameters
    ----------
    df_nodes_local : list of Row (or any mapping-style rows) with keys
        `site_id`, `forecast_units`, `service_time_minutes`, `open_time`,
        `close_time`.
    travel_dict : dict mapping (from_site, to_site) -> travel minutes.
    barge_sdf : object whose `collect()` returns rows with keys `barge_id`,
        `total_capacity_units`, `avg_loading_rate_units_per_min`,
        `working_hours_start`, `working_hours_end`.

    Returns
    -------
    dict with keys `time_matrix`, `demands`, `service_times`, `time_windows`,
    `vehicle_capacities`, `vehicle_time_windows`, `num_vehicles`, `depot`,
    `nodes`, `barge_ids`. Index 0 of every per-node list is the depot.

    Raises
    ------
    ValueError
        If `barge_sdf` yields no rows.
    """
    depot_id = "PORT0"
    missing_travel_minutes = 9999  # penalty used when a site pair has no travel entry

    # Collect exactly once so capacities, working hours and ids all come from
    # the same row ordering (and only one Spark job is triggered).
    barge_rows = barge_sdf.collect()
    if not barge_rows:
        raise ValueError("barge_sdf contains no barges; cannot build a routing model")

    def to_minutes(tstr, default=None):
        # Parse "HH:MM"; a null/blank value falls back to `default` when given.
        if default is not None and (tstr is None or not tstr.strip()):
            tstr = default
        hh, mm = map(int, tstr.split(":"))
        return hh * 60 + mm

    nodes = [depot_id] + [row["site_id"] for row in df_nodes_local]

    # Time matrix: zero on the diagonal, looked-up travel time elsewhere
    time_matrix = [
        [
            0 if i == j else travel_dict.get((ni, nj), missing_travel_minutes)
            for j, nj in enumerate(nodes)
        ]
        for i, ni in enumerate(nodes)
    ]

    # Demands & service times
    demands = [0] + [float(row["forecast_units"]) for row in df_nodes_local]

    # Use the first barge's loading rate; a missing or non-positive rate means
    # no loading-based service time can be derived.
    raw_rate = barge_rows[0]["avg_loading_rate_units_per_min"]
    loading_rate = float(raw_rate) if raw_rate else 0.0

    service_times = [0]
    windows = [(0, 24*60)]
    for row in df_nodes_local:
        qty = float(row["forecast_units"])
        # A zero/missing site service time falls back to 30 minutes.
        site_min = float(row["service_time_minutes"]) if row["service_time_minutes"] else 30
        loading_min = math.ceil(qty / loading_rate) if loading_rate > 0 else 0
        service_times.append(max(site_min, loading_min))
        # Sites without a time window are treated as open all day.
        windows.append((
            to_minutes(row["open_time"], default="00:00"),
            to_minutes(row["close_time"], default="23:59"),
        ))

    # Vehicles
    # int(float(...)) also accepts capacities that are still strings like "100.0".
    vehicle_capacities = [int(float(r["total_capacity_units"])) for r in barge_rows]
    vehicle_time_windows = [
        (to_minutes(r["working_hours_start"]), to_minutes(r["working_hours_end"]))
        for r in barge_rows
    ]
    barge_ids = [r["barge_id"] for r in barge_rows]
    num_vehicles = len(vehicle_capacities)

    return {
        "time_matrix": time_matrix,
        "demands": demands,
        "service_times": service_times,
        "time_windows": windows,
        "vehicle_capacities": vehicle_capacities,
        "vehicle_time_windows": vehicle_time_windows,
        "num_vehicles": num_vehicles,
        "depot": 0,
        "nodes": nodes,
        "barge_ids": barge_ids
    }


In [None]:
def solve_cvrptw(data, week_start_date):
    """
    Solve one week's capacitated vehicle routing problem with time windows.

    Parameters
    ----------
    data : dict
        Routing model with keys `time_matrix`, `demands`, `service_times`,
        `time_windows`, `vehicle_capacities`, `vehicle_time_windows`,
        `num_vehicles`, `depot`, `nodes`, `barge_ids`.
    week_start_date : str
        Week start in "%Y-%m-%d" format; used in messages and to turn minute
        offsets into datetimes.

    Returns
    -------
    dict or None
        Maps each barge id to its ordered list of stops. Each stop is a dict
        with `order`, `site_id`, `qty`, `arrival_min`, `departure_min`,
        `arrival_dt`, `departure_dt`. Depot visits are not listed.
        Returns None when there are no vehicles or the solver finds no
        solution within the time limit.
    """
    # The index manager cannot be built without at least one vehicle.
    if data['num_vehicles'] < 1:
        print(f"No barges available for week {week_start_date}")
        return None

    # Initialize OR-Tools manager and routing model
    manager = pywrapcp.RoutingIndexManager(
        len(data['time_matrix']),
        data['num_vehicles'],
        data['depot']
    )
    routing = pywrapcp.RoutingModel(manager)

    # --- Transit callback (travel + service time) ---
    # Service time of the node being left is charged on the outgoing arc, so a
    # node's cumulative time is the moment service starts there.
    def time_callback(from_idx, to_idx):
        from_node = manager.IndexToNode(from_idx)
        to_node = manager.IndexToNode(to_idx)
        return int(data['time_matrix'][from_node][to_node]) + int(data['service_times'][from_node])

    transit_idx = routing.RegisterTransitCallback(time_callback)
    routing.SetArcCostEvaluatorOfAllVehicles(transit_idx)

    # --- Time dimension ---
    # NOTE(review): slack is 0, so a barge cannot wait at a site for its window
    # to open; only the route start time is free to move. Confirm this is intended.
    horizon = 24 * 60 * 7  # one week in minutes
    routing.AddDimension(
        transit_idx,
        0,          # no slack
        horizon,    # max cumulative time
        False,      # do not force start at zero
        'Time'
    )
    time_dim = routing.GetDimensionOrDie('Time')

    # Set node time windows
    for idx, (w0, w1) in enumerate(data['time_windows']):
        node_index = manager.NodeToIndex(idx)
        w0 = max(0, int(round(w0)))
        w1 = max(w0 + 1, int(round(w1)))  # ensure w1 > w0
        time_dim.CumulVar(node_index).SetRange(w0, w1)

    # --- Demand / capacity dimension ---
    def demand_callback(from_idx):
        return int(data['demands'][manager.IndexToNode(from_idx)])

    demand_idx = routing.RegisterUnaryTransitCallback(demand_callback)
    routing.AddDimensionWithVehicleCapacity(
        demand_idx,
        slack_max=0,  # slack
        vehicle_capacities=data['vehicle_capacities'],
        fix_start_cumul_to_zero=True,  # fix start cumulative to zero
        name='Capacity'
    )

    # --- Vehicle time windows ---
    # Both the route start and the route end must fall in the barge's working hours.
    for vid in range(data['num_vehicles']):
        start_var = time_dim.CumulVar(routing.Start(vid))
        end_var = time_dim.CumulVar(routing.End(vid))
        w0, w1 = data['vehicle_time_windows'][vid]
        start_var.SetRange(w0, w1)
        end_var.SetRange(w0, w1)

    # --- Solver parameters ---
    params = pywrapcp.DefaultRoutingSearchParameters()
    params.first_solution_strategy = routing_enums_pb2.FirstSolutionStrategy.SAVINGS
    params.local_search_metaheuristic = routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH
    params.time_limit.seconds = 60
    params.log_search = False

    # --- Solve ---
    solution = routing.SolveWithParameters(params)
    if not solution:
        print(f"No solution for week {week_start_date}")
        return None

    # --- Extract routes ---
    route = {barge_id: [] for barge_id in data['barge_ids']}
    for vid in range(data['num_vehicles']):
        barge_stops = route[data['barge_ids'][vid]]
        idx = routing.Start(vid)
        order = 0
        while not routing.IsEnd(idx):
            node = manager.IndexToNode(idx)
            if node != data['depot']:
                # Min of the cumul var is the earliest feasible arrival. Max is
                # only the latest feasible arrival, not a departure time, so
                # departure is derived as arrival plus the service time the
                # transit callback charges for this node.
                arrival = solution.Min(time_dim.CumulVar(idx))
                departure = arrival + int(data['service_times'][node])
                barge_stops.append({
                    'order': order,
                    'site_id': data['nodes'][node],
                    'qty': data['demands'][node],
                    'arrival_min': arrival,
                    'departure_min': departure,
                    'arrival_dt': minutes_to_datetime(week_start_date, arrival),
                    'departure_dt': minutes_to_datetime(week_start_date, departure)
                })
                order += 1
            idx = solution.Value(routing.NextVar(idx))

    return route



In [None]:
# Weekly forecasts come from a Lakehouse table.
forecast_sdf = spark.read.table("forecastweekly")

# Site specs, travel times and barge specs come from CSV files with a header row.
site_specs_sdf = spark.read.csv("Files/Sites/site_specs.csv", header=True)
travel_sdf = spark.read.csv("Files/Sites/travel_times.csv", header=True)
barge_sdf = spark.read.csv("Files/Barges/barge_specs.csv", header=True)



In [None]:
# Cast numeric columns to double in every input DataFrame that has them.

sdf_map = {
    "forecast_sdf": forecast_sdf,
    "site_specs_sdf": site_specs_sdf,
    "travel_sdf": travel_sdf,
    "barge_sdf": barge_sdf
}

numeric_cols = ["forecast_units", "total_capacity_units", "avg_loading_rate_units_per_min", "service_time_minutes", "travel_minutes"]

for name, sdf in sdf_map.items():
    # Chain the casts: each withColumn must build on the previous result.
    # Starting from the original frame every time would keep only the last
    # cast for frames that have several numeric columns.
    for col_name in numeric_cols:
        if col_name in sdf.columns:
            sdf = sdf.withColumn(col_name, F.col(col_name).cast("double"))
    sdf_map[name] = sdf

# Re-bind the notebook variables explicitly instead of writing through globals().
forecast_sdf = sdf_map["forecast_sdf"]
site_specs_sdf = sdf_map["site_specs_sdf"]
travel_sdf = sdf_map["travel_sdf"]
barge_sdf = sdf_map["barge_sdf"]

In [None]:
num_weeks_to_run = 3  # Adjust as needed

# Distinct week starts, formatted as yyyy-MM-dd strings and sorted ascending;
# only the first `num_weeks_to_run` are kept.
weeks_list = [
    row["week_start"]
    for row in (
        forecast_sdf
        .select(date_format("week_start", "yyyy-MM-dd").alias("week_start"))
        .distinct()
        .orderBy("week_start")
        .collect()
    )
][:num_weeks_to_run]

print(f"Processing weeks: {weeks_list}")


In [None]:
def process_week(week):
    """Build the inputs for `week` and solve its routing problem.

    Returns (week, routes), where routes is whatever solve_cvrptw returns, or
    (week, None) when build_week_input reports no data for the week.
    """
    nodes_local, travel_lookup = build_week_input(
        forecast_sdf, site_specs_sdf, travel_sdf, week, spark
    )
    if nodes_local is not None:
        model = create_data_model(nodes_local, travel_lookup, barge_sdf)
        return week, solve_cvrptw(model, week)
    return week, None



In [None]:
# Solve the selected weeks one after another, keyed by week start.
routes_all_weeks = {}
total = len(weeks_list)

for seq, wk in enumerate(weeks_list, start=1):
    print(f"[{seq}/{total}] Processing week {wk}...")
    solved_week, solved_routes = process_week(wk)
    routes_all_weeks[solved_week] = solved_routes

print("All weeks processed.")

In [None]:
# Persist each solved week's stops to the `weekly_routes` Delta table.
# The write uses append mode, so re-running this cell adds the same rows again.
for week, week_routes in routes_all_weeks.items():
    # Weeks with no forecast data or no solver solution are stored as None.
    if not week_routes:
        continue

    rows = [
        Row(
            week_start=week,
            barge_id=barge_id,
            order=stop['order'],
            site_id=stop['site_id'],
            qty=stop['qty'],
            arrival_min=stop['arrival_min'],
            departure_min=stop['departure_min'],
            arrival_dt=stop['arrival_dt'],
            departure_dt=stop['departure_dt']
        )
        for barge_id, stops in week_routes.items()
        for stop in stops
    ]

    # A solved week can still have zero stops (every barge route empty).
    # createDataFrame cannot infer a schema from an empty list, so skip it.
    if not rows:
        print(f"No stops to write for week {week}")
        continue

    week_sdf = spark.createDataFrame(rows)
    week_sdf.write.format("delta").mode("append").saveAsTable("weekly_routes")