In [5]:
import os
import re
import glob

import plotly.graph_objects as go
import polars as pl
import ipywidgets as widgets
from IPython.display import display

# -----------------------------
# Config
# -----------------------------

DATA_DIR = "./results/data"  # adjust to wherever the result files live

# Filename grammar. Tighten the trailing ``suffix`` group once the real
# suffix is known. A matching name looks like:
#   branin_baseline_gen150_runs500_cthigh_psize12.parquet
PATTERN = re.compile(
    r"^(?P<problem>.+?)_"
    r"(?P<mode>baseline|relaxed|ranking)_"
    r"gen(?P<gen>\d+)_"
    r"runs(?P<runs>\d+)_"
    r"ct(?P<ctlevel>low|med|high)_"
    r"psize(?P<psize>\d+)"
    r"(?P<suffix>.*?)"
    r"(?P<ext>\.parquet|\.csv)$"
)

GEN_COL, RUN_COL = "generation", "run"

# Columns whose (rounded) values define a "duplicate" solution.
# Default: objective-space degeneration, i.e. equal objective value.
# For exact clones (stronger criterion) use ["x_1", "x_2", "f_1_min"].
SIGNATURE_COLS = ["f_1_min"]

ROUND_DECIMALS = 6  # float signature values are rounded to this before comparison

# Plottable metrics: dropdown label -> summary column names + axis/plot titles.
METRICS = {
    "Duplicate ratio": dict(
        mean="dup_ratio_mean",
        lower="dup_ratio_ci_lower",
        upper="dup_ratio_ci_upper",
        ytitle="Duplicate ratio (mean across runs)",
        title="Degeneration (duplicates) per generation",
    ),
    "Number of duplicates": dict(
        mean="num_dups_mean",
        lower="num_dups_ci_lower",
        upper="num_dups_ci_upper",
        ytitle="# duplicates (mean across runs)",
        title="Degeneration (absolute duplicates) per generation",
    ),
    "Unique solutions": dict(
        mean="unique_mean",
        lower="unique_ci_lower",
        upper="unique_ci_upper",
        ytitle="# unique solutions (mean across runs)",
        title="Unique solutions per generation",
    ),
}


# -----------------------------
# File discovery
# -----------------------------


def list_data_files(data_dir: str):
    """Parse every filename directly inside *data_dir* that matches PATTERN.

    Returns one dict per matching file, in sorted filename order. Each dict
    holds the named regex groups, with ``gen``, ``runs`` and ``psize``
    converted to ``int``, plus the bare filename under ``"file"``.
    Non-matching names are ignored.
    """
    names = sorted(os.path.basename(path) for path in glob.glob(os.path.join(data_dir, "*")))
    parsed = []
    for name in names:
        match = PATTERN.match(name)
        if match is None:
            continue
        record = match.groupdict()
        for numeric_key in ("gen", "runs", "psize"):
            record[numeric_key] = int(record[numeric_key])
        record["file"] = name
        parsed.append(record)
    return parsed


# Index of all parsed result files; every selector and plot draws from it.
rows = list_data_files(DATA_DIR)
if not rows:
    raise ValueError(f"No files matched pattern in {DATA_DIR}. Adjust PATTERN/DATA_DIR.")


def uniq_sorted(vals):
    """Return the distinct values of *vals* in ascending order.

    Numbers (``int``/``float``) sort numerically and come first; every other
    value sorts by its ``str()`` form after them. The key is a
    ``(kind, value)`` tuple, so a number is never compared directly with a
    string. The previous key mixed raw numbers and strings, which raises
    ``TypeError`` when both kinds appear in the input.
    """

    def sort_key(v):
        if isinstance(v, (int, float)):
            return (0, v)
        return (1, str(v))

    return sorted(set(vals), key=sort_key)


def filter_rows(problem=None, mode=None, gen=None, runs=None, ctlevel=None, psize=None):
    """Return the records of the module-level ``rows`` index that match every given field.

    A ``None`` argument means "any value" for that field. Matching records
    keep their order from ``rows``.
    """
    criteria = {
        field: wanted
        for field, wanted in (
            ("problem", problem),
            ("mode", mode),
            ("gen", gen),
            ("runs", runs),
            ("ctlevel", ctlevel),
            ("psize", psize),
        )
        if wanted is not None
    }
    return [
        record
        for record in rows
        if all(record[field] == wanted for field, wanted in criteria.items())
    ]


def find_file(sel):
    """Return the filename of the first record matching selection dict *sel*, or ``None``."""
    matches = filter_rows(**sel)
    if not matches:
        return None
    return matches[0]["file"]


# -----------------------------
# Reading + degeneration computation
# -----------------------------


def _read_file(path: str) -> pl.DataFrame:
    """Load *path* with polars: ``.csv`` files as CSV, anything else as Parquet."""
    reader = pl.read_csv if path.endswith(".csv") else pl.read_parquet
    return reader(path)


def compute_duplicates_timeseries(df: pl.DataFrame) -> pl.DataFrame:
    """Summarize population degeneration per generation, across runs.

    Two rows of the same (run, generation) are duplicates when their
    SIGNATURE_COLS values are equal; float signature columns are first
    rounded to ROUND_DECIMALS. For every (run, generation) this computes:

    - ``unique``: number of distinct signatures;
    - ``num_dups``: rows - unique;
    - ``dup_ratio``: num_dups / rows.

    Per generation, it then reports each metric's mean across runs and a
    normal-approximation 95% CI: mean +/- 1.96 * std / sqrt(n_runs), using
    the sample std.

    Args:
        df: input rows; must contain GEN_COL, RUN_COL and every column in
            SIGNATURE_COLS.

    Returns:
        DataFrame sorted by ``generation``. Its columns are ``generation``
        and then, for each of ``dup_ratio``, ``num_dups`` and ``unique``:
        ``<m>_mean``, ``<m>_ci_lower`` and ``<m>_ci_upper``.

    Raises:
        ValueError: if a required column is missing.
    """
    # basic schema checks
    for col in [GEN_COL, RUN_COL] + SIGNATURE_COLS:
        if col not in df.columns:
            raise ValueError(f"Missing required column '{col}'. Columns: {df.columns}")

    z_95 = 1.96  # two-sided 95% quantile of the standard normal
    metrics = ("dup_ratio", "num_dups", "unique")

    # Round float signature columns so values differing only by float noise
    # are treated as the same signature.
    float_cols = [c for c in SIGNATURE_COLS if df.schema[c] in (pl.Float32, pl.Float64)]
    if float_cols:
        df = df.with_columns([pl.col(c).round(ROUND_DECIMALS).alias(c) for c in float_cols])

    # Occurrences of each signature inside each (run, generation).
    occ = df.group_by([RUN_COL, GEN_COL] + SIGNATURE_COLS).len().rename({"len": "occ"})

    # Per-(run, generation) metrics. Each row of ``occ`` is one distinct
    # signature, so the group length is the unique count and the occ sum is
    # the number of rows in that generation.
    per_run_gen = occ.group_by([RUN_COL, GEN_COL]).agg(
        [
            pl.len().alias("unique"),
            pl.sum("occ").alias("popsize"),
            (pl.sum("occ") - pl.len()).alias("num_dups"),
            ((pl.sum("occ") - pl.len()).cast(pl.Float64) / pl.sum("occ").cast(pl.Float64)).alias("dup_ratio"),
        ]
    )

    # Across runs: run count, mean and sample std per generation.
    # pl.len() replaces the deprecated no-argument pl.count().
    summary = (
        per_run_gen.rename({GEN_COL: "generation"})
        .group_by("generation")
        .agg(
            [pl.len().alias("n_runs")]
            + [pl.col(m).cast(pl.Float64).mean().alias(f"{m}_mean") for m in metrics]
            + [pl.col(m).cast(pl.Float64).std().alias(f"{m}_std") for m in metrics]
        )
    )

    # CI half-width = 1.96 * std / sqrt(n). With a single run the sample std
    # is undefined, and so are the bounds.
    sqrt_n = pl.col("n_runs").sqrt()
    bounds = []
    for m in metrics:
        half_width = z_95 * pl.col(f"{m}_std") / sqrt_n
        bounds.append((pl.col(f"{m}_mean") - half_width).alias(f"{m}_ci_lower"))
        bounds.append((pl.col(f"{m}_mean") + half_width).alias(f"{m}_ci_upper"))

    columns = ["generation"] + [
        f"{m}_{part}" for m in metrics for part in ("mean", "ci_lower", "ci_upper")
    ]
    return summary.with_columns(bounds).select(columns).sort("generation")


# -----------------------------
# Plotting helpers
# -----------------------------


def add_fig(fig, data: pl.DataFrame, label: str, metric_key: str):
    """Add one dataset's mean line and shaded 95% CI band to *fig*.

    Args:
        fig: plotly Figure to draw into.
        data: per-generation summary with a ``generation`` column and the
            mean/lower/upper columns named by ``METRICS[metric_key]``.
        label: legend name for this dataset.
        metric_key: key into METRICS.

    Plotly normally gives every trace its own palette color, so a band
    trace added without a color would not match its line, and the next
    dataset's line would skip a color. Here the line and band share one
    explicit color and a legend group, so they match and toggle together.
    Generations whose CI bounds are null or NaN (e.g. only one run) are left
    out of the band instead of breaking the filled polygon.
    """
    from plotly.colors import qualitative

    spec = METRICS[metric_key]

    # Datasets already drawn = traces shown in the legend (bands are hidden).
    palette = qualitative.Plotly
    n_drawn = sum(1 for trace in fig.data if trace.showlegend is not False)
    color = palette[n_drawn % len(palette)]

    x = data["generation"].to_list()
    y = data[spec["mean"]].to_list()
    fig.add_trace(
        go.Scatter(x=x, y=y, mode="lines", name=label, line={"color": color}, legendgroup=label)
    )

    band = data.filter(pl.col(spec["lower"]).is_finite() & pl.col(spec["upper"]).is_finite())
    if band.height == 0:
        return
    bx = band["generation"].to_list()
    lo = band[spec["lower"]].to_list()
    up = band[spec["upper"]].to_list()
    # Closed polygon: upper bound left-to-right, then lower bound right-to-left.
    fig.add_trace(
        go.Scatter(
            x=bx + bx[::-1],
            y=up + lo[::-1],
            fill="toself",
            fillcolor=color,
            line={"width": 0, "color": color},
            opacity=0.25,
            hoverinfo="skip",
            showlegend=False,
            legendgroup=label,
            name=f"{label} 95% CI",
        )
    )


def load_and_compute(fname: str) -> pl.DataFrame:
    """Read result file *fname* from DATA_DIR and return its per-generation duplicate summary."""
    frame = _read_file(os.path.join(DATA_DIR, fname))
    return compute_duplicates_timeseries(frame)


def plot_selected(sel_a, sel_b, metric_key):
    """Overlay *metric_key* for selections A and B in one figure.

    If either selection resolves to no file, print both selections with the
    file each resolved to, and return without plotting.
    """
    resolved = [(sel, find_file(sel)) for sel in (sel_a, sel_b)]

    if any(fname is None for _, fname in resolved):
        print("No file matches one of the selections.")
        for tag, (sel, fname) in zip(("A:", "B:"), resolved):
            print(tag, sel, "->", fname)
        return

    fig = go.Figure()
    for sel, fname in resolved:
        label = f"{sel['problem']} ({sel['mode']}, ct={sel['ctlevel']}, p={sel['psize']})"
        summary = load_and_compute(fname)
        add_fig(fig, summary, label, metric_key)

    spec = METRICS[metric_key]
    fig.update_layout(
        title=spec["title"],
        xaxis_title="Generation",
        yaxis_title=spec["ytitle"],
    )
    fig.show()


ALL_MODES = ["baseline", "relaxed", "ranking"]


def plot_all_modes(base_sel, metric_key):
    """Overlay *metric_key* for every mode in ALL_MODES under one base selection.

    Args:
        base_sel: selection dict; its ``mode`` entry (if present) is replaced
            by each mode in turn, and the other fields are used as-is.
        metric_key: key into METRICS.

    Modes without a matching file are reported on stdout. If no mode has a
    file, the function returns without rendering an empty figure.
    """
    fig = go.Figure()
    missing = []

    for mode in ALL_MODES:
        sel = {**base_sel, "mode": mode}
        fname = find_file(sel)
        if fname is None:
            missing.append(mode)
            continue

        label = f"{sel['problem']} ({mode}, ct={sel['ctlevel']}, p={sel['psize']})"
        add_fig(fig, load_and_compute(fname), label, metric_key)

    if missing:
        print(f"Missing files for modes: {missing} (selection base: {base_sel})")
    if len(missing) == len(ALL_MODES):
        # Nothing to draw; an axes-only figure would just be confusing.
        return

    fig.update_layout(
        title=f"{METRICS[metric_key]['title']} — all modes",
        xaxis_title="Generation",
        yaxis_title=METRICS[metric_key]["ytitle"],
    )
    fig.show()


# -----------------------------
# UI: cascading dataset selectors and plot buttons
# -----------------------------


def make_dependent_selector(prefix: str):
    """Build one cascading group of dropdowns over the parsed file index ``rows``.

    The cascade order is problem -> mode -> gen -> runs -> ct level -> psize.
    Each dropdown only offers values that occur in ``rows`` together with the
    choices above it. When an upstream choice changes, each downstream
    dropdown keeps its current value if it is still offered; otherwise it
    falls back to the first value in ``uniq_sorted`` order.

    Args:
        prefix: label prefix shown in every dropdown description (e.g. "A").

    Returns:
        ``(controls, selection)``:

        - ``controls``: the six Dropdown widgets in cascade order.
        - ``selection``: a zero-argument callable returning the current
          choices as a dict whose keys match ``filter_rows``' parameters.
    """

    def make_dropdown(label):
        return widgets.Dropdown(description=f"{prefix} {label}:", layout=widgets.Layout(width="70%"))

    w_problem = make_dropdown("problem")
    w_mode = make_dropdown("mode")
    w_gen = make_dropdown("gen")
    w_runs = make_dropdown("runs")
    w_ct = make_dropdown("ct level")
    w_psize = make_dropdown("psize")

    # True while this selector rewrites its own options/values. Assigning
    # ``Dropdown.options`` resets the value to the first option and fires the
    # "value" observers. Without this guard, that transient value would start
    # a nested refresh that overwrites the user's downstream choices before
    # the preferred value is restored. The cascade is instead driven
    # explicitly by the refresh_* functions below.
    muted = False

    def set_options(w, opts, prefer=None):
        """Replace *w*'s options, keeping *prefer* (default: current value) if still offered."""
        nonlocal muted
        if prefer is None:
            prefer = w.value  # read before the options reset overwrites it
        opts = uniq_sorted(opts)
        was_muted, muted = muted, True
        try:
            w.options = opts
            if not opts:
                w.value = None
            elif prefer in opts:
                w.value = prefer
            else:
                w.value = opts[0]
        finally:
            muted = was_muted

    def refresh_from_problem(_=None):
        r = filter_rows(problem=w_problem.value)
        set_options(w_mode, [x["mode"] for x in r])
        refresh_from_mode()

    def refresh_from_mode(_=None):
        r = filter_rows(problem=w_problem.value, mode=w_mode.value)
        set_options(w_gen, [x["gen"] for x in r])
        refresh_from_gen()

    def refresh_from_gen(_=None):
        r = filter_rows(problem=w_problem.value, mode=w_mode.value, gen=w_gen.value)
        set_options(w_runs, [x["runs"] for x in r])
        refresh_from_runs()

    def refresh_from_runs(_=None):
        r = filter_rows(
            problem=w_problem.value, mode=w_mode.value, gen=w_gen.value, runs=w_runs.value
        )
        set_options(w_ct, [x["ctlevel"] for x in r])
        refresh_from_ct()

    def refresh_from_ct(_=None):
        r = filter_rows(
            problem=w_problem.value,
            mode=w_mode.value,
            gen=w_gen.value,
            runs=w_runs.value,
            ctlevel=w_ct.value,
        )
        set_options(w_psize, [x["psize"] for x in r])

    def on_external_change(refresh):
        """Wrap *refresh* as a value observer that ignores this selector's own updates."""

        def handler(change):
            if not muted:
                refresh()

        return handler

    for widget, refresh in (
        (w_problem, refresh_from_problem),
        (w_mode, refresh_from_mode),
        (w_gen, refresh_from_gen),
        (w_runs, refresh_from_runs),
        (w_ct, refresh_from_ct),
    ):
        widget.observe(on_external_change(refresh), names="value")

    set_options(w_problem, [x["problem"] for x in rows])
    refresh_from_problem()

    def selection():
        return dict(
            problem=w_problem.value,
            mode=w_mode.value,
            gen=w_gen.value,
            runs=w_runs.value,
            ctlevel=w_ct.value,
            psize=w_psize.value,
        )

    controls = [w_problem, w_mode, w_gen, w_runs, w_ct, w_psize]
    return controls, selection


# Which METRICS entry to plot; defaults to the first one defined.
metric_dd = widgets.Dropdown(
    options=list(METRICS),
    description="Metric:",
    value=next(iter(METRICS)),
    layout=widgets.Layout(width="70%"),
)

# Two independent selectors: datasets A and B are compared in one plot.
a_controls, a_sel = make_dependent_selector("A")
b_controls, b_sel = make_dependent_selector("B")

MODE_ORDER = ["baseline", "relaxed", "ranking"]


def _next_mode(current: str) -> str:
    """Return the mode following *current* in MODE_ORDER, wrapping around.

    If *current* is not in MODE_ORDER, return the first mode that differs
    from it. *current* itself is returned only when no other mode exists.
    """
    if current in MODE_ORDER:
        following = MODE_ORDER[(MODE_ORDER.index(current) + 1) % len(MODE_ORDER)]
        if following != current:
            return following
    return next((m for m in MODE_ORDER if m != current), current)


def set_b_default_mode():
    """Point selector B at a mode different from A's, if B's problem offers more than one mode."""
    b_problem = b_sel()["problem"]
    available = uniq_sorted([record["mode"] for record in filter_rows(problem=b_problem)])
    if len(available) <= 1:
        return
    a_mode = a_sel()["mode"]
    w_mode = b_controls[1]
    w_mode.value = next((mode for mode in available if mode != a_mode), available[0])


set_b_default_mode()

out = widgets.Output()  # shared area that every plot renders into
button = widgets.Button(description="Plot", button_style="primary")
plot_all_button = widgets.Button(description="Plot all modes", button_style="info")
quick_compare = widgets.Button(description="Quick compare modes", button_style="info")


def on_click_plot(_):
    """Replace the output area with the A-vs-B plot for the selected metric."""
    with out:
        out.clear_output(wait=True)
        selection_a, selection_b = a_sel(), b_sel()
        plot_selected(selection_a, selection_b, metric_dd.value)


button.on_click(on_click_plot)


def on_click_plot_all(_):
    """Plot every mode using A's current settings; A's own mode choice is ignored."""
    base = {key: value for key, value in a_sel().items() if key != "mode"}

    with out:
        out.clear_output(wait=True)
        plot_all_modes(base, metric_dd.value)


plot_all_button.on_click(on_click_plot_all)


def on_click_quick_compare(_):
    """Copy A's settings into B with a different mode, then plot A vs B.

    B's dropdowns are set top-down (problem, mode, gen, runs, ct, psize), so
    each one's options already reflect the choices above it. A value that B
    does not offer is skipped, and B keeps its own choice for that field.
    Assigning such a value would make the Dropdown raise a TraitError.
    """
    a = a_sel()
    w_problem, w_mode, w_gen, w_runs, w_ct, w_psize = b_controls

    def set_if_offered(w, value):
        if value in w.options:
            w.value = value

    set_if_offered(w_problem, a["problem"])

    # Walk MODE_ORDER starting after A's mode; take the first mode that
    # differs from A's and exists for B's (now updated) problem.
    candidate = a["mode"]
    for _step in range(len(MODE_ORDER)):
        candidate = _next_mode(candidate)
        if candidate != a["mode"] and candidate in w_mode.options:
            w_mode.value = candidate
            break

    set_if_offered(w_gen, a["gen"])
    set_if_offered(w_runs, a["runs"])
    set_if_offered(w_ct, a["ctlevel"])
    set_if_offered(w_psize, a["psize"])

    with out:
        out.clear_output(wait=True)
        plot_selected(a_sel(), b_sel(), metric_dd.value)


quick_compare.on_click(on_click_quick_compare)

# When the cell is re-run, close the previous UI so its widgets are torn down.
if "ui" in globals():
    try:
        ui.close()
    except Exception:
        pass  # best effort: a stale widget must not block building the new UI

ui = widgets.VBox(
    children=[
        metric_dd,
        widgets.HTML("<b>Data 1 (A)</b>"),
        *a_controls,
        widgets.HTML("<b>Data 2 (B)</b>"),
        *b_controls,
        button,
        quick_compare,
        plot_all_button,
        out,
    ]
)

display(ui)


VBox(children=(Dropdown(description='Metric:', layout=Layout(width='70%'), options=('Duplicate ratio', 'Number…