# In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re  # for regex parsing of stop-to-stop distance

def extract_truck_temporal_speed(base_dir, target_densities, case_labels, num_trials=50):
    """Collect per-truck temporal mean speeds from the simulation output tree.

    The directory layout walked under ``base_dir`` is::

        March26(Vehicle and Timestep)_Case_*/Kappa_<k>/Density_<d>/
            PassengerRate_<r>/VehicleData/Trial_<t>_D<d>_K<k>_R<r>*.csv

    Case directories are sorted by name and paired positionally with
    ``case_labels``. For each trial file that has both a ``Vehicle Type`` and
    a ``Mean Speed Across Time`` column, one record is emitted per row whose
    vehicle type is ``'truck'`` and whose speed is not NaN.

    Args:
        base_dir: Directory containing the case directories.
        target_densities: Mapping from ``(case label, kappa string)`` to the
            density whose ``Density_<d>`` directory is read. Pairs missing
            from the mapping, or whose directory does not exist, are skipped
            with a printed message.
        case_labels: One label per case directory, in sorted-directory order.
        num_trials: Trials ``1..num_trials`` are looked up; trials without a
            matching CSV are skipped silently.

    Returns:
        A tuple ``(df_truck, kappas, rates)``. ``df_truck`` always has the
        columns ``Case``, ``Truck Speed``, ``Kappa``, ``Rate`` and
        ``StopDist`` (even when no rows were found). ``kappas`` and ``rates``
        are the directory-derived values seen during the walk, sorted as
        strings; a kappa or rate is reported even if no data was ultimately
        read for it.

    Raises:
        ValueError: If the number of case directories differs from the number
            of labels.
    """
    columns = ["Case", "Truck Speed", "Kappa", "Rate", "StopDist"]
    stop_dist_pattern = re.compile(r'_S(\d+)')

    cases = sorted(
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
        and d.startswith('March26(Vehicle and Timestep)_Case_')
    )

    if len(cases) != len(case_labels):
        raise ValueError("Number of cases and case labels must match!")

    all_kappas = set()
    all_rates = set()
    truck_speed_data = []

    for case, label in zip(cases, case_labels):
        print(f"{label}: {case}")
        case_path = os.path.join(base_dir, case)

        # Directory listings are sorted so the row order of the result is
        # reproducible; os.listdir returns entries in arbitrary order.
        for kappa_dir in sorted(os.listdir(case_path)):
            if not kappa_dir.startswith('Kappa_'):
                continue
            kappa_value = kappa_dir.split('_')[1]
            all_kappas.add(kappa_value)
            kappa_path = os.path.join(case_path, kappa_dir)

            density = target_densities.get((label, kappa_value))
            if density is None:
                print(f"⚠️ Skipping: No density for ({label}, {kappa_value})")
                continue

            density_path = os.path.join(kappa_path, f'Density_{density}')
            if not os.path.exists(density_path):
                print(f"⚠️ Skipping: No directory {density_path}")
                continue

            for rate_dir in sorted(os.listdir(density_path)):
                if not rate_dir.startswith('PassengerRate_'):
                    continue
                rate_value = rate_dir.split('_')[1]
                all_rates.add(rate_value)
                vehicle_data_path = os.path.join(density_path, rate_dir, 'VehicleData')
                if not os.path.exists(vehicle_data_path):
                    continue

                # List the directory once instead of once per trial, and sort
                # it so the file chosen for a trial does not depend on the
                # filesystem's enumeration order.
                csv_files = sorted(
                    f for f in os.listdir(vehicle_data_path) if f.endswith('.csv')
                )

                for trial in range(1, num_trials + 1):
                    prefix = f'Trial_{trial}_D{density}_K{kappa_value}_R{rate_value}'
                    file_name = next(
                        (f for f in csv_files if f.startswith(prefix)), None
                    )
                    if file_name is None:
                        continue

                    df = pd.read_csv(os.path.join(vehicle_data_path, file_name))
                    if ('Vehicle Type' not in df.columns
                            or 'Mean Speed Across Time' not in df.columns):
                        continue

                    trucks = df[df['Vehicle Type'] == 'truck']

                    # Stop-to-stop distance is encoded in the file name as
                    # "_S<digits>" when present.
                    stop_dist_match = stop_dist_pattern.search(file_name)
                    stop_distance = (
                        int(stop_dist_match.group(1)) if stop_dist_match else None
                    )

                    truck_speed_data.extend(
                        {
                            "Case": label,
                            "Truck Speed": speed,
                            "Kappa": kappa_value,
                            "Rate": rate_value,
                            "StopDist": stop_distance,
                        }
                        for speed in trucks['Mean Speed Across Time'].dropna()
                    )

    # Explicit columns keep the schema stable when nothing was collected, so
    # callers can still index df_truck['Kappa'] / df_truck['Rate'].
    df_truck = pd.DataFrame(truck_speed_data, columns=columns)
    return df_truck, sorted(all_kappas), sorted(all_rates)


def plot_truck_speed_boxplots(df_truck, case_labels, kappa_value, rate_value):
    """Draw, save and show a horizontal box plot of truck speeds per case.

    Rows of ``df_truck`` are filtered to the given ``Kappa`` and ``Rate``
    values; one box per ``Case`` is drawn from the ``Truck Speed`` column.
    The figure is written to the current working directory as
    ``Truck_Temporal_Mean_Speed_Kappa_<kappa>_Rate_<rate>_targetted_0.2.png``.

    Args:
        df_truck: Frame with ``Case``, ``Truck Speed``, ``Kappa`` and ``Rate``
            columns. An empty frame (with or without columns) is skipped.
        case_labels: Preferred top-to-bottom order of the cases. Labels with
            no data for this combination are omitted; cases present in the
            data but absent from this list are appended in order of
            appearance.
        kappa_value: Value compared for equality against the ``Kappa`` column.
        rate_value: Value compared for equality against the ``Rate`` column.

    Returns:
        None. When there is nothing to plot a message is printed and no
        figure is created.
    """
    # An extraction that collected nothing can yield a frame without the
    # filter columns; bail out before indexing into it.
    if df_truck.empty:
        print(f"Skipping: No truck speed data for Kappa {kappa_value}, Rate {rate_value}")
        return

    df = df_truck[(df_truck['Kappa'] == kappa_value) & (df_truck['Rate'] == rate_value)].copy()

    if df.empty:
        print(f"Skipping: No truck speed data for Kappa {kappa_value}, Rate {rate_value}")
        return

    # Apply the theme before creating the figure so that every figure,
    # including the first, is built under the same settings. The figure is
    # only created once we know there is data, so skipped combinations do not
    # leave blank figures open.
    sns.set_theme(style="ticks", context="talk", font_scale=1.2)
    plt.figure(figsize=(22, 12))

    present = list(df['Case'].unique())
    preferred = list(case_labels) if case_labels is not None else []
    order = [label for label in preferred if label in present]
    order += [label for label in present if label not in order]

    ax = sns.boxplot(
        data=df,
        x="Truck Speed",
        y="Case",
        order=order,
        color="#FFBF00",
        showfliers=False,
        linewidth=1.5,
        orient="h"
    )

    plt.xticks(fontsize=55)
    plt.yticks(fontsize=85)
    plt.xlabel(r"$<v_{truck}>_{temporal}$", fontsize=55)
    plt.ylabel("")
    plt.grid(True, linestyle='--', alpha=0.5)
    # Keep a full frame around the plot: all four spines are shown.
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(True)
    plt.tight_layout()

    outname = f"Truck_Temporal_Mean_Speed_Kappa_{kappa_value}_Rate_{rate_value}_targetted_0.2.png"
    plt.savefig(outname, dpi=200)
    plt.show()
    plt.close()
    print(f"✅ Saved: {outname}")

# ------------------------------------------------------------------
# Run configuration: adjust the values in this section as needed.
# ------------------------------------------------------------------
# The output tree is resolved three directory levels above the current
# working directory.
base_dir = os.path.abspath(
    os.path.join(os.getcwd(), '..', '..', '..', 'Load Anywhere Output')
)

# Display labels (mathtext); paired positionally with the sorted case
# directories by extract_truck_temporal_speed.
case_labels = [
    "OO",
    r"$O\tilde{T}$",
    r"$\tilde{J}O$",
    r"$\tilde{J}\tilde{T}$",
]

# Every (case label, kappa) combination targets the same density of 0.2.
target_densities = {
    (label, kappa): 0.2
    for label in case_labels
    for kappa in ("0", "0.2", "0.4")
}

# Gather the truck temporal mean speeds together with the kappa and rate
# values discovered on disk.
df_truck, all_kappas, all_rates = extract_truck_temporal_speed(
    base_dir=base_dir,
    target_densities=target_densities,
    case_labels=case_labels,
)

# One box plot per (kappa, rate) combination.
for kappa in all_kappas:
    for rate in all_rates:
        plot_truck_speed_boxplots(
            df_truck, case_labels, kappa_value=kappa, rate_value=rate
        )