In [None]:
import os
import re
import glob

import plotly.graph_objects as go
import polars as pl
import ipywidgets as widgets
from IPython.display import display

# Directory that is scanned for "*_summary.parquet" files and from which the
# selected files are read.
DATA_DIR = "./results/summary"

# Plottable metrics, keyed by the name shown in the "Metric" dropdown.
# Each entry names the data columns used for the curve ("mean") and for the
# band around it ("lower"/"upper"), plus the y-axis title and figure title.
METRICS = {
    "Best-so-far (feasible) objective": {
        "mean": "run_best_so_far_mean",
        "lower": "best_ci_lower",
        "upper": "best_ci_upper",
        "ytitle": "Best-so-far objective (mean)",
        "title": "Best-so-far (feasible) objective per generation",
    },
    "Hypervolume": {
        "mean": "hv_mean",
        "lower": "hv_ci_lower",
        "upper": "hv_ci_upper",
        "ytitle": "Hypervolume (mean)",
        "title": "Hypervolume per generation",
    },
}

# Expected file-name layout:
#   <problem>_<mode>_gen<int>_runs<int>_ct<text>_psize<int>_summary.parquet
# NOTE: the "problem" and "mode" groups are both lazy, so a problem name that
# itself contains an underscore is split at its first underscore and the
# remainder ends up in "mode".
PATTERN = re.compile(
    r"^(?P<problem>.+?)_(?P<mode>.+?)_gen(?P<gen>\d+)_runs(?P<runs>\d+)_ct(?P<ct>.+?)_psize(?P<psize>\d+)_summary\.parquet$"
)


def list_summary_files(data_dir: str):
    """Describe every summary parquet file found directly inside *data_dir*.

    File names are processed in sorted order. Names that do not match
    ``PATTERN`` are skipped silently.

    Args:
        data_dir: Directory searched for ``*_summary.parquet`` files.

    Returns:
        A list of dicts, one per matching file, holding the named groups of
        ``PATTERN`` (``gen``, ``runs`` and ``psize`` converted to ``int``)
        plus the base file name under ``"file"``.
    """
    found_paths = glob.glob(os.path.join(data_dir, "*_summary.parquet"))
    names = [os.path.basename(path) for path in found_paths]
    names.sort()

    parsed = []
    for name in names:
        match = PATTERN.match(name)
        if match is None:
            continue
        info = match.groupdict()
        # These groups are digit-only by construction of the pattern.
        for numeric_field in ("gen", "runs", "psize"):
            info[numeric_field] = int(info[numeric_field])
        info["file"] = name
        parsed.append(info)
    return parsed


# Index the summary files once; the filtering helpers and the dropdown options
# below are all served from this in-memory list.
rows = list_summary_files(DATA_DIR)
if not rows:
    # Fail early: without any rows the selectors below would have no options.
    raise ValueError(f"No summary files matched pattern in {DATA_DIR}")


def uniq_sorted(vals):
    """Return the distinct items of *vals* as a sorted list.

    Numbers (``int``/``float``) are ordered numerically and placed first;
    every other item is ordered by its ``str()`` form. Ranking by kind keeps
    the sort from raising ``TypeError`` when numbers and non-numbers are mixed.

    Args:
        vals: Iterable of hashable values.

    Returns:
        A new list with duplicates removed, sorted as described above.
    """

    def sort_key(value):
        # Tuples always compare like-with-like: the rank separates numbers
        # from the rest, so a number is never compared against a string.
        if isinstance(value, (int, float)):
            return (0, value, "")
        return (1, 0, str(value))

    return sorted(set(vals), key=sort_key)


def filter_rows(problem=None, mode=None, gen=None, runs=None, ct=None, psize=None):
    """Return the entries of the module-level ``rows`` matching all given fields.

    An argument left as ``None`` places no constraint on that field. When no
    argument is given, the ``rows`` list itself is returned.
    """
    requested = {
        "problem": problem,
        "mode": mode,
        "gen": gen,
        "runs": runs,
        "ct": ct,
        "psize": psize,
    }
    constraints = [(field, wanted) for field, wanted in requested.items() if wanted is not None]
    if not constraints:
        return rows
    return [row for row in rows if all(row[field] == wanted for field, wanted in constraints)]


def find_file(sel):
    """Return the file name of the first row matching *sel*, or ``None``.

    Args:
        sel: Dict of keyword arguments accepted by ``filter_rows``.
    """
    matches = filter_rows(**sel)
    if not matches:
        return None
    return matches[0]["file"]


# ---------- plotting ----------
def add_fig(fig, data: pl.DataFrame, label: str, metric_key: str):
    """Add one data set to *fig* as a mean line plus a shaded band.

    The line and its band are given the same explicit colour and the same
    legend group. Without that, the band is coloured as an independent trace
    (so it does not match its line) and stays visible when the line is
    toggled off in the legend.

    Args:
        fig: Figure that receives the two traces.
        data: Frame with a ``"generation"`` column and the three columns named
            by ``METRICS[metric_key]`` (``"mean"``, ``"lower"``, ``"upper"``).
        label: Legend name of the line; the band is named ``"<label> 95% CI"``.
        metric_key: Key into ``METRICS`` selecting the columns to plot.
    """
    # Local import: the palette is only needed here.
    from plotly.colors import qualitative

    spec = METRICS[metric_key]
    x = data["generation"].to_list()
    y = data[spec["mean"]].to_list()
    lo = data[spec["lower"]].to_list()
    up = data[spec["upper"]].to_list()

    # Every call appends exactly two traces (line + band), so the number of
    # traces already present tells which data set this is.
    # NOTE(review): assumes nothing else adds traces to `fig` — confirm.
    pair_index = len(fig.data) // 2
    palette = qualitative.Plotly
    color = palette[pair_index % len(palette)]
    # The index keeps groups distinct even when two data sets share a label.
    group = f"{pair_index}:{label}"

    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode="lines",
            name=label,
            line={"color": color},
            legendgroup=group,
        )
    )
    fig.add_trace(
        go.Scatter(
            # Closed polygon: upper bound left-to-right, lower bound back.
            x=x + x[::-1],
            y=up + lo[::-1],
            # Explicit mode so short series do not get default markers.
            mode="lines",
            fill="toself",
            fillcolor=color,
            line={"width": 0, "color": color},
            opacity=0.25,
            hoverinfo="skip",
            showlegend=False,
            name=f"{label} 95% CI",
            legendgroup=group,
        )
    )


def plot_selected(sel_a, sel_b, metric_key):
    """Plot the chosen metric for two selected summary files in one figure.

    If either selection matches no file, a short diagnostic is printed and
    nothing is plotted.

    Legend labels are ``"<problem> (<mode>)"``. When both selections would get
    the same label, the fields in which they differ (gen, runs, ct, psize) are
    appended so the two curves can be told apart.

    Args:
        sel_a: Selection dict for the first curve (keys as in ``filter_rows``).
        sel_b: Selection dict for the second curve.
        metric_key: Key into ``METRICS`` choosing what to plot.
    """
    file_a = find_file(sel_a)
    file_b = find_file(sel_b)

    if file_a is None or file_b is None:
        print("No file matches one of the selections.")
        print("A:", sel_a, "->", file_a)
        print("B:", sel_b, "->", file_b)
        return

    def make_label(sel, other):
        # Keep the short label unless it collides with the other curve's label.
        base = f"{sel['problem']} ({sel['mode']})"
        if base != f"{other['problem']} ({other['mode']})":
            return base
        extras = [
            f"{key}={sel.get(key)}"
            for key in ("gen", "runs", "ct", "psize")
            if sel.get(key) != other.get(key)
        ]
        return f"{base} [{', '.join(extras)}]" if extras else base

    fig = go.Figure()

    for sel, other, fname in [(sel_a, sel_b, file_a), (sel_b, sel_a, file_b)]:
        label = make_label(sel, other)
        # Sort so the line and band are drawn in generation order.
        data = pl.read_parquet(os.path.join(DATA_DIR, fname)).sort("generation")
        add_fig(fig, data, label, metric_key)

    fig.update_layout(
        title=METRICS[metric_key]["title"],
        xaxis_title="Generation",
        yaxis_title=METRICS[metric_key]["ytitle"],
    )
    fig.show()


# ---------- dependent selector UI ----------
def make_dependent_selector(prefix: str):
    """Build six cascading dropdowns for picking one summary file.

    The dropdowns cover problem, mode, gen, runs, ct and psize, in that order.
    Each dropdown only offers values found in the module-level ``rows`` that
    agree with the values currently chosen in all earlier dropdowns.

    Args:
        prefix: Text placed in front of every dropdown description (e.g. "A").

    Returns:
        A ``(controls, selection)`` pair. ``controls`` is the list of the six
        dropdown widgets in cascade order; ``selection`` is a zero-argument
        callable returning the current values as a dict keyed by field name.
    """
    w_problem = widgets.Dropdown(description=f"{prefix} problem:", layout=widgets.Layout(width="70%"))
    w_mode = widgets.Dropdown(description=f"{prefix} mode:", layout=widgets.Layout(width="70%"))
    w_gen = widgets.Dropdown(description=f"{prefix} gen:", layout=widgets.Layout(width="70%"))
    w_runs = widgets.Dropdown(description=f"{prefix} runs:", layout=widgets.Layout(width="70%"))
    w_ct = widgets.Dropdown(description=f"{prefix} ct:", layout=widgets.Layout(width="70%"))
    w_psize = widgets.Dropdown(description=f"{prefix} psize:", layout=widgets.Layout(width="70%"))

    def set_options(w, opts, prefer=None):
        """Replace the options of *w* and pick a value from the new options.

        *prefer* wins when it is among the new options; otherwise the current
        value is kept if still valid, else the first option is used.
        Callers pass the widget's previous value as *prefer*, captured before
        the options are reassigned here.
        """
        opts = uniq_sorted(opts)
        w.options = opts
        if not opts:
            w.value = None
            return
        if prefer in opts:
            w.value = prefer
        else:
            # keep current if valid, else first
            w.value = w.value if w.value in opts else opts[0]

    # Each refresh_* narrows ``rows`` by everything chosen so far, repopulates
    # the next dropdown, then continues the cascade explicitly.
    # NOTE(review): assigning ``w.value`` in set_options may also fire the
    # observers registered below, so downstream refreshes can run more than
    # once per change — presumably harmless, confirm.
    def refresh_from_problem(_=None):
        """Repopulate the mode dropdown for the chosen problem, then cascade."""
        p = w_problem.value
        r = filter_rows(problem=p)
        set_options(w_mode, [x["mode"] for x in r], prefer=w_mode.value)
        refresh_from_mode()

    def refresh_from_mode(_=None):
        """Repopulate the gen dropdown for the chosen problem/mode, then cascade."""
        p, m = w_problem.value, w_mode.value
        r = filter_rows(problem=p, mode=m)
        set_options(w_gen, [x["gen"] for x in r], prefer=w_gen.value)
        refresh_from_gen()

    def refresh_from_gen(_=None):
        """Repopulate the runs dropdown for the choices so far, then cascade."""
        p, m, g = w_problem.value, w_mode.value, w_gen.value
        r = filter_rows(problem=p, mode=m, gen=g)
        set_options(w_runs, [x["runs"] for x in r], prefer=w_runs.value)
        refresh_from_runs()

    def refresh_from_runs(_=None):
        """Repopulate the ct dropdown for the choices so far, then cascade."""
        p, m, g, rn = w_problem.value, w_mode.value, w_gen.value, w_runs.value
        r = filter_rows(problem=p, mode=m, gen=g, runs=rn)
        set_options(w_ct, [x["ct"] for x in r], prefer=w_ct.value)
        refresh_from_ct()

    def refresh_from_ct(_=None):
        """Repopulate the psize dropdown; this is the end of the cascade."""
        p, m, g, rn, ct = w_problem.value, w_mode.value, w_gen.value, w_runs.value, w_ct.value
        r = filter_rows(problem=p, mode=m, gen=g, runs=rn, ct=ct)
        set_options(w_psize, [x["psize"] for x in r], prefer=w_psize.value)

    # wire callbacks (order matters: each callback refreshes downstream)
    # psize has no observer because nothing depends on it.
    w_problem.observe(refresh_from_problem, names="value")
    w_mode.observe(refresh_from_mode, names="value")
    w_gen.observe(refresh_from_gen, names="value")
    w_runs.observe(refresh_from_runs, names="value")
    w_ct.observe(refresh_from_ct, names="value")

    # initialize: fill the problem dropdown from all rows, then run the cascade
    # once so every dropdown starts with a consistent set of options.
    set_options(w_problem, [x["problem"] for x in rows])
    refresh_from_problem()

    def selection():
        """Return the current dropdown values keyed like ``filter_rows`` arguments."""
        return dict(
            problem=w_problem.value,
            mode=w_mode.value,
            gen=w_gen.value,
            runs=w_runs.value,
            ct=w_ct.value,
            psize=w_psize.value,
        )

    # Cascade order; callers index into this list (index 1 is the mode dropdown).
    controls = [w_problem, w_mode, w_gen, w_runs, w_ct, w_psize]
    return controls, selection


# Metric picker: one option per METRICS entry, starting on the first one.
metric_dd = widgets.Dropdown(
    options=list(METRICS),
    description="Metric:",
    value=next(iter(METRICS)),
    layout=widgets.Layout(width="70%"),
)

# One independent group of cascading dropdowns per curve being compared.
(a_controls, a_sel), (b_controls, b_sel) = (
    make_dependent_selector(tag) for tag in ("A", "B")
)


# Start B on a different mode than A when B's problem offers more than one.
def set_b_default_mode():
    """Point selector B at a mode other than A's, if B's problem has several.

    Does nothing when the problem currently selected in B has at most one mode.
    """
    b_problem = b_sel()["problem"]
    available = uniq_sorted([row["mode"] for row in filter_rows(problem=b_problem)])
    if len(available) <= 1:
        return

    mode_of_a = a_sel()["mode"]
    for candidate in available:
        if candidate != mode_of_a:
            break
    else:
        # Every available mode equals A's; fall back to the first one.
        candidate = available[0]
    b_controls[1].value = candidate  # index 1 is the mode dropdown


# Apply the "different mode for B" default once, before the UI is shown.
set_b_default_mode()

# The button triggers plotting; the Output widget is the area the click
# handler clears and plots into.
button = widgets.Button(description="Plot", button_style="primary")
out = widgets.Output()


def on_click_plot(_):
    """Redraw the comparison inside ``out`` for the current selections.

    The argument passed by the button is ignored.
    """
    with out:
        # wait=True delays clearing until new output arrives.
        out.clear_output(wait=True)
        selection_a = a_sel()
        selection_b = b_sel()
        plot_selected(selection_a, selection_b, metric_dd.value)


# Hook the handler up to the button, then show the whole UI as one column:
# metric picker, selector group A, selector group B, plot button, output area.
button.on_click(on_click_plot)

display(
    widgets.VBox(
        [metric_dd, widgets.HTML("<b>Data 1 (A)</b>")]
        + a_controls
        + [widgets.HTML("<b>Data 2 (B)</b>")]
        + b_controls
        + [button, out]
    )
)

VBox(children=(Dropdown(description='Metric:', layout=Layout(width='70%'), options=('Best-so-far (feasible) ob…