# In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re

# Column order of the frame returned by extract_waiting_times.
_WAIT_COLUMNS = ["Case", "Waiting Time", "Kappa", "Rate", "Density", "Stop Spacing", "Status"]
# Columns a trial CSV must contain to be used at all.
_REQUIRED_TRIAL_COLUMNS = {'Spawning Time', 'Waiting Time', 'Riding Status'}
# Case directory names look like "Case_B_StopToStop_200_Evenly_Spaced_Stops":
# group 1 is the case letter, group 2 the stop spacing.
_CASE_DIR_RE = re.compile(r'Case_([A-Z])_StopToStop_(\d+)_Evenly_Spaced_Stops')


def _subdirs(path, prefix):
    """Return the sorted names of subdirectories of ``path`` starting with ``prefix``."""
    # Sorting makes the output row order deterministic (os.listdir order is
    # arbitrary); the isdir check keeps stray files such as "Kappa_notes.txt"
    # from being descended into.
    return sorted(
        name for name in os.listdir(path)
        if name.startswith(prefix) and os.path.isdir(os.path.join(path, name))
    )


def _rate_labels(rate_dirs):
    """Map ``PassengerRate_<value>`` directory names to "Low" or "High".

    Only the directories holding the smallest and the largest numeric rate are
    labelled; intermediate rates are left out so they are not pooled into
    "High". Names whose value does not parse as a float are ignored. Returns an
    empty dict when fewer than two distinct rates are present.
    """
    values = {}
    for name in rate_dirs:
        try:
            values[name] = float(name.split('_')[1])
        except (IndexError, ValueError):
            continue
    if not values:
        return {}
    low, high = min(values.values()), max(values.values())
    if low == high:
        return {}
    return {
        name: ("Low" if value == low else "High")
        for name, value in values.items()
        if value in (low, high)
    }


def _load_trial(file_path, min_spawn_time):
    """Read one trial CSV and return the usable passengers' waiting time and status.

    Returns a frame with columns "Waiting Time" and "Status", or ``None`` when
    the file is empty or lacks a required column. Passengers with
    "Spawning Time" <= ``min_spawn_time`` and rows with a missing waiting time
    are dropped. A "Riding Status" of exactly "Waiting" becomes "Still Waiting";
    every other value (including a missing one) becomes "Completed Waiting".
    """
    try:
        trial = pd.read_csv(file_path)
    except pd.errors.EmptyDataError:
        return None
    if not _REQUIRED_TRIAL_COLUMNS.issubset(trial.columns):
        return None
    keep = (trial['Spawning Time'] > min_spawn_time) & trial['Waiting Time'].notna()
    trial = trial[keep]
    still_waiting = trial['Riding Status'] == "Waiting"
    status = still_waiting.map({True: "Still Waiting", False: "Completed Waiting"})
    return pd.DataFrame({
        "Waiting Time": trial['Waiting Time'].to_numpy(),
        "Status": status.to_numpy(),
    })


def extract_waiting_times(base_dir, case_labels, num_trials=50, min_spawn_time=6999):
    """Collect per-passenger waiting times from the simulation result tree.

    Layout under ``base_dir`` that this function walks::

        Case_<X>_StopToStop_<spacing>_Evenly_Spaced_Stops/
            Kappa_<k>/Density_<d>/PassengerRate_<r>/PassengerData/
                Trial_<n>_D<d>_K<k>_R<r>_S<spacing>.csv

    Args:
        base_dir: Root directory containing the ``Case_*`` directories.
        case_labels: Display labels indexed by case letter (A -> 0, B -> 1, ...);
            cases whose letter has no label are skipped.
        num_trials: Trials ``1..num_trials`` are looked up; missing files are skipped.
        min_spawn_time: Only passengers with "Spawning Time" strictly greater
            than this value are kept.

    Returns:
        DataFrame with columns Case, Waiting Time, Kappa, Rate, Density,
        Stop Spacing, Status (one row per passenger). Kappa and Density are the
        strings taken from the directory names, Rate is "Low"/"High" (only the
        lowest and highest rate of each density are used) and Stop Spacing is
        an int. When nothing matches, the frame is empty but has these columns.
    """
    frames = []
    for case_dir in _subdirs(base_dir, "Case_"):
        match = _CASE_DIR_RE.match(case_dir)
        if not match:
            continue
        case_index = ord(match.group(1)) - ord('A')
        if case_index >= len(case_labels):
            continue
        label = case_labels[case_index]
        stop_spacing = int(match.group(2))
        case_path = os.path.join(base_dir, case_dir)

        for kappa_dir in _subdirs(case_path, 'Kappa_'):
            kappa_value = kappa_dir.split('_')[1]
            kappa_path = os.path.join(case_path, kappa_dir)

            for density_dir in _subdirs(kappa_path, 'Density_'):
                density_value = density_dir.split('_')[1]
                density_path = os.path.join(kappa_path, density_dir)

                rate_labels = _rate_labels(_subdirs(density_path, 'PassengerRate_'))
                for rate_dir, rate_label in rate_labels.items():
                    rate_value = rate_dir.split('_')[1]
                    passenger_data_path = os.path.join(density_path, rate_dir, 'PassengerData')
                    if not os.path.isdir(passenger_data_path):
                        continue

                    # Identical for every trial of this rate directory.
                    meta = {
                        "Case": label,
                        "Kappa": kappa_value,
                        "Rate": rate_label,
                        "Density": density_value,
                        "Stop Spacing": stop_spacing,
                    }
                    for trial in range(1, num_trials + 1):
                        filename = f'Trial_{trial}_D{density_value}_K{kappa_value}_R{rate_value}_S{stop_spacing}.csv'
                        file_path = os.path.join(passenger_data_path, filename)
                        if not os.path.exists(file_path):
                            continue
                        trial_df = _load_trial(file_path, min_spawn_time)
                        if trial_df is None or trial_df.empty:
                            continue
                        frames.append(trial_df.assign(**meta))

    if not frames:
        return pd.DataFrame({column: [] for column in _WAIT_COLUMNS})
    return pd.concat(frames, ignore_index=True)[_WAIT_COLUMNS]


# Default box colours, one per case in ``case_labels`` order
# (the first four matplotlib "tab10" colours).
_DEFAULT_CASE_PALETTE = ("#1F77B4", "#FF7F0E", "#2CA02C", "#D62728")


def plot_waiting_time_distributions(df, case_labels, output_dir=".", palette=None):
    """Save waiting-time box plots, one PNG per (status, density, kappa, rate) combination.

    Each figure plots "Waiting Time" against "Stop Spacing", with one box per
    case (hue ordered by ``case_labels``). Outliers are not drawn. Combinations
    with no rows are skipped.

    Args:
        df: Frame with the columns Status, Density, Kappa, Rate, Stop Spacing,
            Waiting Time and Case. Density and Kappa values must be convertible
            with ``float``, because they are sorted numerically.
        case_labels: Hue order and legend entries.
        output_dir: Directory the PNGs are written to (created if missing).
            Defaults to the current working directory.
        palette: Colours for the cases. ``None`` uses the four default colours.
            When there are more than four cases, it falls back to a seaborn
            palette of the right length.

    Side effects:
        Calls ``sns.set_theme`` (global styling). Writes 200-dpi files named
        ``WaitingTime_<status>_Density_<d>_Kappa_<k>_Rate_<rate>.png``.
    """
    sns.set_theme(style="ticks", context="talk", font_scale=1.2)

    if palette is None:
        if len(case_labels) <= len(_DEFAULT_CASE_PALETTE):
            # Pass exactly one colour per hue level so seaborn does not warn
            # about a palette/level length mismatch.
            palette = list(_DEFAULT_CASE_PALETTE[:len(case_labels)])
        else:
            palette = sns.color_palette(n_colors=len(case_labels))
    os.makedirs(output_dir, exist_ok=True)

    unique_densities = sorted(df['Density'].unique(), key=float)
    unique_kappas = sorted(df['Kappa'].unique(), key=float)
    unique_rates = ["Low", "High"]

    for status in ["Completed Waiting", "Still Waiting"]:
        status_df = df[df["Status"] == status]

        for density in unique_densities:
            for kappa in unique_kappas:
                for rate in unique_rates:
                    subset = status_df[
                        (status_df["Density"] == density)
                        & (status_df["Kappa"] == kappa)
                        & (status_df["Rate"] == rate)
                    ]
                    if subset.empty:
                        continue

                    fig = plt.figure(figsize=(10, 6))
                    try:
                        sns.boxplot(
                            data=subset,
                            x="Stop Spacing",
                            y="Waiting Time",
                            hue="Case",
                            hue_order=case_labels,
                            palette=palette,
                            showfliers=False,
                        )

                        plt.xlabel(r"$\Delta x_{stop}$", fontsize=40)
                        plt.ylabel(r"$t_{wait}$", fontsize=42)
                        plt.xticks(fontsize=35)
                        plt.yticks(fontsize=30)
                        plt.grid(True, linestyle='--', alpha=0.6)
                        plt.legend(title="", fontsize=22)
                        # Keep all four spines visible (closed box around the axes).
                        for spine in plt.gca().spines.values():
                            spine.set_visible(True)
                        plt.tight_layout()

                        filename = (
                            f"WaitingTime_{status}_Density_{density}"
                            f"_Kappa_{kappa}_Rate_{rate}.png"
                        )
                        plt.savefig(os.path.join(output_dir, filename), dpi=200)
                    finally:
                        # Release the figure even if plotting or saving fails.
                        plt.close(fig)


# =========================
# 🔧 Parameters
# =========================
# Results directory, resolved three levels above the current working directory.
base_dir = os.path.abspath(os.path.join(os.getcwd(), '../../..', 'With Designated Stops Results'))
# One display label per case letter, in order A, B, C, D.
case_labels = ["OO", r"$O\tilde{T}$", r"$\tilde{J}O$", r"$\tilde{J}\tilde{T}$"]

# Run extraction and plotting only when executed directly (script or notebook
# kernel), not when this file is imported as a module.
if __name__ == "__main__":
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"Results directory not found: {base_dir}")
    df_wait = extract_waiting_times(base_dir, case_labels, num_trials=50)
    if df_wait.empty:
        print(f"No waiting-time data found under {base_dir}; nothing to plot.")
    else:
        plot_waiting_time_distributions(df_wait, case_labels)