In [12]:
import pandas as pd
import numpy as np
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import (
    VBox, HBox, FloatRangeSlider, IntSlider, Checkbox,
    Button, Layout, HTML, Output
)

# ---------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------
# Workbook read by load_data() via pandas.read_excel (relative to the notebook's working directory).
EXCEL_PATH = "../results/Gaia_homogeneous_target_selection_2025.08.15.xlsx"
# Columns that never get a slider or histogram; names are compared after
# lower-casing and removing spaces/underscores (see is_explicit_excluded).
EXPLICITLY_EXCLUDED_COLS = {"source_id_dr2", "source_id_dr3"}

# Columns that still get a filter slider but no histogram panel.
EXCLUDE_FROM_PLOT = {
    "HWO_match",
    "LOPS2_match",
    "TESS_confirmed_match",
    "TESS_candidate_match",
    "sum_score"
}

# Layout of the filter-slider grid (CSS width / right-margin strings).
SLIDERS_PER_ROW = 3
SLIDER_WIDTH = "26%"
SLIDER_COLUMN_GAP = "40px"

# The slider for this column starts (and resets) with its upper handle at
# min(HZ_DETECTION_LIMIT_CAP, column maximum) instead of the full range.
HZ_DETECTION_LIMIT_VAR = "HZ Detection Limit [M_Earth]"
HZ_DETECTION_LIMIT_CAP = 4.0

# Dynamic V_mag binning: the filtered V_mag range is split into
# N_VMAG_GROUPS equal-width groups on every redraw.
V_MAG_COL = "V_mag"         # name of the magnitude column; a ValueError is raised at load time if it is missing
N_VMAG_GROUPS = 5
# One colour per V_mag group (ColorBrewer Set1 hues), ordered
# red, orange, green, blue, purple. Indexed modulo its length, so raising
# N_VMAG_GROUPS above 5 reuses colours.
VMAG_COLORS = [
    "#e41a1c",  # red
    "#ff7f00",  # orange
    "#4daf4a",  # green
    "#377eb8",  # blue
    "#984ea3"   # purple
]
BARMODE = "stack"  # Plotly layout barmode: the V_mag groups are stacked within each panel

# Alpha used for the histogram bars; draw() passes it as both the marker
# opacity and the trace opacity.
BAR_OPACITY = 0.9  # 0.0 (fully transparent) to 1.0 (opaque)

# ---------------------------------------------------------------
# DATA LOADING
# ---------------------------------------------------------------
def load_data(path):
    """Read the target-selection workbook into a DataFrame.

    Column headers are converted to stripped strings. If the table has a
    "Parallax" column but no "Distance [pc]" column, the distance is derived
    as ``1000 / Parallax`` and any non-finite result (e.g. from a zero
    parallax) is replaced by NaN.

    Parameters
    ----------
    path : str or os.PathLike
        Location of the Excel file.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    """
    if not Path(path).exists():
        raise FileNotFoundError(path)
    df = pd.read_excel(path)
    # Excel headers are not guaranteed to be strings (numeric or date headers
    # come back as int/float/datetime), so coerce before stripping.
    df.columns = [str(c).strip() for c in df.columns]
    if "Distance [pc]" not in df.columns and "Parallax" in df.columns:
        # assumes Parallax is in milliarcseconds (hence the factor 1000) — TODO confirm
        # NOTE(review): negative parallaxes yield negative distances here;
        # only non-finite values are masked.
        with np.errstate(divide="ignore", invalid="ignore"):
            df["Distance [pc]"] = 1000.0 / df["Parallax"]
            df.loc[~np.isfinite(df["Distance [pc]"]), "Distance [pc]"] = np.nan
    return df

# Load the workbook once when the cell runs; df_raw keeps the table exactly as loaded.
df_raw = load_data(EXCEL_PATH)
# Fail early: draw() groups every histogram by V_MAG_COL, so the column must exist.
if V_MAG_COL not in df_raw.columns:
    raise ValueError(f"Column '{V_MAG_COL}' not in data. Adjust V_MAG_COL.")
# Working copy used for column selection, slider ranges and filtering.
df = df_raw.copy()

# ---------------------------------------------------------------
# COLUMN SELECTION
# ---------------------------------------------------------------
def _norm_name(c):
    return c.lower().replace(" ", "").replace("_", "")

def is_source_id(col):
    """Return True if the column name contains "source id" in any spelling.

    The name is compared after `_norm_name` normalisation (lower case, spaces
    and underscores removed), so "source_id_dr3", "Source ID" and "SourceId"
    all match the single substring test below. A separate check for the
    literal "source_id" could never succeed on a normalised name and is
    therefore not needed.
    """
    return "sourceid" in _norm_name(col)

def is_hip(col):
    """Return True if the normalised column name is exactly "hip", "hipnumber" or "hipid"."""
    normalised = _norm_name(col)
    return normalised in ("hip", "hipnumber", "hipid")

def is_explicit_excluded(col):
    """Return True if ``col`` matches an entry of EXPLICITLY_EXCLUDED_COLS after name normalisation."""
    for excluded in EXPLICITLY_EXCLUDED_COLS:
        if _norm_name(col) == _norm_name(excluded):
            return True
    return False

# Numeric columns that are neither identifiers (source id / HIP) nor
# explicitly excluded are eligible for sliders and histograms.
eligible_cols = [
    name
    for name in df.columns
    if pd.api.types.is_numeric_dtype(df[name])
    and not is_source_id(name)
    and not is_hip(name)
    and not is_explicit_excluded(name)
]

# Keep the column order of the loaded table, then append any eligible column
# that is not yet listed.
ordered_cols = [name for name in df_raw.columns if name in eligible_cols]
ordered_cols.extend([name for name in eligible_cols if name not in ordered_cols])

# ---------------------------------------------------------------
# HELPERS
# ---------------------------------------------------------------
def is_discrete_like(series: pd.Series, discrete_unique_max=15):
    """Decide whether a column should be treated as discrete rather than continuous.

    A series counts as discrete when it is boolean, has no non-null values,
    is integer-typed with at most ``discrete_unique_max`` distinct values, or
    has at most three distinct values regardless of dtype.
    """
    if series.dtype == bool:
        return True

    valid = series.dropna()
    if valid.empty:
        return True

    n_unique = valid.nunique()
    few_integer_levels = (
        pd.api.types.is_integer_dtype(valid) and n_unique <= discrete_unique_max
    )
    return bool(few_integer_levels or n_unique <= 3)

def auto_step(vmin, vmax):
    """Return a "nice" slider step of the form 1, 2 or 5 times a power of ten.

    The step is the smallest such number that is at least 1/200 of the range
    ``vmax - vmin``, so a slider over the range has at most about 200 stops.

    Parameters
    ----------
    vmin, vmax : float
        Lower and upper end of the slider range.

    Returns
    -------
    float
        The step size; 0.1 when the range is empty, negative or not finite.
    """
    rng = vmax - vmin
    # A NaN/inf range would otherwise propagate NaN into the slider step.
    if not np.isfinite(rng) or rng <= 0:
        return 0.1
    raw = rng / 200
    mag = 10 ** np.floor(np.log10(raw))
    for m in (1, 2, 5):
        if raw <= m * mag:
            return m * mag
    # raw lies in (5*mag, 10*mag): round up to the next decade. Returning
    # ``mag`` here would give a step smaller than ``raw``.
    return 10 * mag

# ---------------------------------------------------------------
# BUILD SLIDERS
# ---------------------------------------------------------------
# One FloatRangeSlider per filterable column, keyed by column name.
sliders = {}
# Column names in the order their sliders were created (drives the layout).
filter_vars = []
# Columns excluded from plotting come first. EXCLUDE_FROM_PLOT is a set, whose
# iteration order is arbitrary, so take those columns in DataFrame column
# order to keep the slider order stable between runs.
filter_candidate_cols = [c for c in df.columns if c in EXCLUDE_FROM_PLOT]
filter_candidate_cols += [c for c in ordered_cols if c not in EXCLUDE_FROM_PLOT]
added = set()

for col in filter_candidate_cols:
    if col not in df.columns or col in added:
        continue
    s = df[col].dropna()
    if s.empty:
        continue
    vmin, vmax = float(s.min()), float(s.max())

    # Flag/score columns and discrete-like columns whose values are all whole
    # numbers get integer bounds and a unit step; everything else uses the
    # data range with an automatically chosen step.
    whole_numbers = (
        (col in EXCLUDE_FROM_PLOT or is_discrete_like(s))
        and np.allclose(s, np.round(s))
    )
    if whole_numbers:
        lower_bound = int(np.floor(vmin))
        upper_bound = int(np.ceil(vmax))
        step = 1
    else:
        lower_bound, upper_bound = vmin, vmax
        step = auto_step(vmin, vmax)

    slider = FloatRangeSlider(
        description=col,
        min=lower_bound, max=upper_bound,
        value=(lower_bound, upper_bound),
        step=step,
        layout=Layout(width=SLIDER_WIDTH),
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    # The HZ detection-limit slider starts with its upper handle at the cap.
    if col == HZ_DETECTION_LIMIT_VAR:
        lower = slider.min
        upper = min(HZ_DETECTION_LIMIT_CAP, slider.max)
        if upper >= lower:
            slider.value = (lower, upper)

    sliders[col] = slider
    filter_vars.append(col)
    added.add(col)

# ---------------------------------------------------------------
# BINNING
# ---------------------------------------------------------------
def select_bins(series: pd.Series, min_bins: int, max_bins: int):
    """Choose histogram bin edges for one variable.

    Returns None when the series has no non-null values. Discrete-like data
    get edges halfway between neighbouring distinct values. A constant series
    gets a single unit-wide bin. Samples with fewer than 30 values get
    between 3 and ``max_bins`` equal-width bins. Otherwise the bin width is
    the median of the Freedman–Diaconis, Scott and Sturges widths, and the
    resulting bin count is clipped to ``[min_bins, max_bins]``.
    """
    values = series.dropna().astype(float)
    count = values.shape[0]
    if count == 0:
        return None

    if is_discrete_like(values):
        levels = np.sort(values.unique())
        if levels.size == 1:
            return np.array([levels[0] - 0.5, levels[0] + 0.5])
        half_gaps = np.diff(levels) / 2
        return np.concatenate([
            [levels[0] - half_gaps[0]],
            levels[:-1] + half_gaps,
            [levels[-1] + half_gaps[-1]],
        ])

    lo, hi = values.min(), values.max()
    span = hi - lo
    if span == 0:
        return np.array([lo - 0.5, lo + 0.5])

    # Small samples: bin count from sqrt(n) or the number of distinct values.
    if count < 30:
        n_small = max(int(np.sqrt(count)), len(np.unique(values)))
        n_small = int(np.clip(n_small, 3, max_bins))
        return np.linspace(lo, hi, n_small + 1)

    # Candidate bin widths from the three rules; only positive ones are kept.
    candidates = []
    q25, q75 = np.percentile(values, [25, 75])
    iqr = q75 - q25
    if iqr > 0:
        width_fd = 2 * iqr / (count ** (1/3))
        if width_fd > 0:
            candidates.append(width_fd)
    spread = values.std(ddof=1) if count > 1 else 0.0
    if spread > 0:
        width_scott = 3.5 * spread / (count ** (1/3))
        if width_scott > 0:
            candidates.append(width_scott)
    n_sturges = np.ceil(np.log2(count) + 1) if count > 1 else 1
    if span > 0 and n_sturges > 0:
        candidates.append(span / n_sturges)

    if not candidates:
        n_fallback = int(np.clip(int(np.sqrt(count)), min_bins, max_bins))
        return np.linspace(lo, hi, n_fallback + 1)

    width = np.median(candidates)
    if width <= 0 or not np.isfinite(width):
        n_bins = int(np.sqrt(count))
    else:
        n_bins = int(np.round((hi - lo) / width))
    n_bins = max(int(np.clip(n_bins, min_bins, max_bins)), 1)
    return np.linspace(lo, hi, n_bins + 1)

# ---------------------------------------------------------------
# FILTER APPLICATION
# ---------------------------------------------------------------
def current_filter_ranges():
    """Return a dict mapping each filter column to its slider's current ``value``."""
    ranges = {}
    for name, widget in sliders.items():
        ranges[name] = widget.value
    return ranges

def apply_filters(df_in: pd.DataFrame):
    """Return a copy of the rows of ``df_in`` that pass every active slider filter.

    A slider still at its full (min, max) range is treated as inactive and
    does not restrict the rows. An active slider keeps rows whose value lies
    inside the selected range, both ends inclusive.
    """
    keep = pd.Series(True, index=df_in.index)
    for name, selected in current_filter_ranges().items():
        widget = sliders[name]
        if selected == (widget.min, widget.max):
            continue
        low, high = selected
        column = df_in[name]
        keep &= (column >= low) & (column <= high)
    return df_in[keep].copy()

# ---------------------------------------------------------------
# WIDGETS
# ---------------------------------------------------------------
# Histogram controls: automatic vs. global binning, bin-count limits and normalisation.
auto_bins_checkbox = Checkbox(
    value=True,
    description="Auto bins per variable",
    indent=False,
)
global_bins_slider = IntSlider(
    value=40,
    min=5,
    max=150,
    step=1,
    description="Global bins",
    layout=Layout(width="225px"),
)
min_bins_slider = IntSlider(
    value=10,
    min=2,
    max=50,
    step=1,
    description="Min bins",
    layout=Layout(width="205px"),
)
max_bins_slider = IntSlider(
    value=70,
    min=20,
    max=200,
    step=1,
    description="Max bins",
    layout=Layout(width="215px"),
)
normalize_checkbox = Checkbox(
    value=False,
    description="Normalize (density)",
    indent=False,
)

# Action buttons, the "Filtered rows" status line and the output area for the figure.
update_button = Button(description="Update", button_style="primary")
reset_button = Button(description="Reset", button_style="warning")
status_html = HTML("")
out_plot = Output()

# ---------------------------------------------------------------
# DRAW
# ---------------------------------------------------------------
def format_bin_label(edges, i):
    """Return the label "<lower>–<upper>" (two decimals each) for bin ``i`` of ``edges``."""
    lower, upper = edges[i], edges[i + 1]
    return "{:.2f}–{:.2f}".format(lower, upper)

def _vmag_trace_style(var, name, color, bin_idx, show_legend, line_width):
    """Return the keyword arguments shared by every per-V_mag-group trace."""
    # NOTE(review): BAR_OPACITY is set on both the marker and the trace; if
    # Plotly combines the two, bars render more transparent than BAR_OPACITY
    # alone — confirm which level is intended.
    return dict(
        marker=dict(
            color=color,
            line=dict(width=line_width, color="rgba(0,0,0,0.5)"),
            opacity=BAR_OPACITY,
        ),
        opacity=BAR_OPACITY,
        name=name,
        # One legend group per V_mag bin, so that a legend click toggles that
        # bin in every panel instead of toggling all traces at once.
        legendgroup=f"vmag-{bin_idx}",
        showlegend=show_legend,
        hovertemplate=f"{var}<br>%{{x}} : %{{y}}<br>{name}<extra></extra>",
    )


def draw():
    """Redraw the grid of stacked histograms for the currently filtered rows.

    The filtered V_mag range is split into up to N_VMAG_GROUPS equal-width
    groups. Every plotted variable gets one subplot in which each group
    contributes one stacked trace: a bar trace over the distinct values for
    discrete-like variables, a histogram otherwise. The figure is written
    into ``out_plot`` and the row count into ``status_html``. Nothing is
    returned.
    """
    with out_plot:
        out_plot.clear_output(wait=True)

        filtered = apply_filters(df)
        status_html.value = f"<b>Filtered rows:</b> {len(filtered)} / {len(df)}"
        if filtered.empty:
            print("No data after filtering.")
            return

        # Dynamic V_mag bin edges based on filtered V_mag
        vmag_filtered = filtered[V_MAG_COL].dropna()
        if vmag_filtered.empty:
            print("No V_mag values after filtering.")
            return

        vmin_cur, vmax_cur = vmag_filtered.min(), vmag_filtered.max()
        if np.isclose(vmin_cur, vmax_cur):
            # Degenerate (single value) case
            edges = np.array([vmin_cur - 0.5, vmin_cur + 0.5])
        else:
            edges = np.linspace(vmin_cur, vmax_cur, N_VMAG_GROUPS + 1)

        # Group index per row (NaN where V_mag is missing).
        vmag_bins = pd.cut(filtered[V_MAG_COL], bins=edges, include_lowest=True, labels=False)
        n_bins_in_use = len(edges) - 1  # can be 1 if degenerate

        plot_vars = [v for v in ordered_cols if v not in EXCLUDE_FROM_PLOT]
        if not plot_vars:
            print("No variables to plot (all excluded).")
            return

        # Roughly square grid, but always between 3 and 6 columns wide.
        n_vars = len(plot_vars)
        base_cols = int(np.ceil(np.sqrt(n_vars)))
        n_cols = max(3, min(6, base_cols))
        n_rows = int(np.ceil(n_vars / n_cols))

        fig = make_subplots(
            rows=n_rows, cols=n_cols,
            horizontal_spacing=0.04,
            vertical_spacing=0.07
        )

        histnorm = "probability density" if normalize_checkbox.value else None
        use_auto = auto_bins_checkbox.value
        min_bins = min_bins_slider.value
        max_bins = max_bins_slider.value
        global_bins = global_bins_slider.value

        # V_mag groups that already own a legend entry. Tracking this per
        # group (rather than "first variable only") guarantees an entry even
        # for a group that has no data in the first panel.
        legend_shown = set()

        for idx, var in enumerate(plot_vars):
            # Each variable keeps its grid cell even when it has no data.
            row, col = divmod(idx, n_cols)
            row += 1
            col += 1

            s_all = filtered[var].dropna()
            if s_all.empty:
                continue

            discrete = is_discrete_like(s_all)
            categories = None
            hist_kwargs = {}
            if discrete:
                categories = np.sort(s_all.unique())
            elif use_auto:
                edges_var = select_bins(s_all, min_bins=min_bins, max_bins=max_bins)
                if edges_var is None or len(edges_var) < 2:
                    edges_var = np.array([s_all.min() - 0.5, s_all.min() + 0.5])
                # The same explicit bins are used for every group so the
                # stacked bars line up.
                hist_kwargs = dict(
                    xbins=dict(
                        start=edges_var[0],
                        end=edges_var[-1],
                        size=edges_var[1] - edges_var[0],
                    )
                )
            else:
                hist_kwargs = dict(nbinsx=global_bins)

            for bin_idx in range(n_bins_in_use):
                s_bin = filtered.loc[vmag_bins == bin_idx, var].dropna()
                if s_bin.empty:
                    continue

                style = _vmag_trace_style(
                    var=var,
                    name=f"V_mag {format_bin_label(edges, bin_idx)}",
                    color=VMAG_COLORS[bin_idx % len(VMAG_COLORS)],
                    bin_idx=bin_idx,
                    show_legend=bin_idx not in legend_shown,
                    line_width=0.3 if discrete else 0.25,
                )

                if discrete:
                    counts = s_bin.value_counts()
                    y_vals = [counts.get(cat, 0) for cat in categories]
                    if histnorm:
                        # Normalise within the group so its bars sum to 1.
                        tot = sum(y_vals)
                        if tot > 0:
                            y_vals = [v / tot for v in y_vals]
                    trace = go.Bar(x=categories.astype(float), y=y_vals, **style)
                else:
                    trace = go.Histogram(x=s_bin, histnorm=histnorm, **hist_kwargs, **style)

                fig.add_trace(trace, row=row, col=col)
                legend_shown.add(bin_idx)

            fig.update_xaxes(title_text=var, row=row, col=col)

        height = 250 * n_rows + 80
        fig.update_layout(
            title=f"Stacked Histograms by Dynamic V_mag Bins (Vars={n_vars}; Excluded={len(EXCLUDE_FROM_PLOT)})",
            bargap=0.03,
            height=height,
            template="plotly_white",
            margin=dict(l=60, r=25, t=70, b=60),
            barmode=BARMODE
        )
        fig.update_yaxes(rangemode="tozero")
        fig.show()

# ---------------------------------------------------------------
# CALLBACKS
# ---------------------------------------------------------------
# True while on_reset() is restoring defaults, so the per-widget observers
# do not trigger one full redraw for every widget that changes.
_redraw_suspended = False


def on_update(_):
    """Button handler: redraw the figure with the current widget state."""
    draw()


def _on_value_change(_change):
    """Observer for widget value changes: redraw unless a reset is in progress."""
    if not _redraw_suspended:
        draw()


def on_reset(_):
    """Button handler: restore all sliders and controls to their defaults, then redraw once.

    Every slider returns to its full range, except the HZ detection-limit
    slider, whose upper handle returns to the configured cap.
    """
    global _redraw_suspended
    _redraw_suspended = True
    try:
        for col, sl in sliders.items():
            upper = sl.max
            if col == HZ_DETECTION_LIMIT_VAR:
                capped = min(HZ_DETECTION_LIMIT_CAP, sl.max)
                if capped >= sl.min:
                    upper = capped
            sl.value = (sl.min, upper)
        auto_bins_checkbox.value = True
        global_bins_slider.value = 40
        min_bins_slider.value = 10
        max_bins_slider.value = 70
        normalize_checkbox.value = False
    finally:
        # Always re-enable redraws, even if restoring a widget failed.
        _redraw_suspended = False
    draw()


update_button.on_click(on_update)
reset_button.on_click(on_reset)

# Any change to a filter slider or histogram control triggers a redraw.
for w in list(sliders.values()) + [
    auto_bins_checkbox, global_bins_slider,
    min_bins_slider, max_bins_slider,
    normalize_checkbox
]:
    w.observe(_on_value_change, names="value")

# ---------------------------------------------------------------
# LAYOUT
# ---------------------------------------------------------------
filter_header = HTML(f"<b>Filters (n={len(sliders)})</b>")
slider_widgets = [sliders[name] for name in filter_vars]

# Arrange the sliders in rows of SLIDERS_PER_ROW; every slider except the
# last one in a row gets a right margin that acts as the column gap.
slider_rows = []
for start in range(0, len(slider_widgets), SLIDERS_PER_ROW):
    row_widgets = slider_widgets[start:start + SLIDERS_PER_ROW]
    for widget in row_widgets[:-1]:
        if hasattr(widget, "layout"):
            widget.layout.margin = f"0 {SLIDER_COLUMN_GAP} 0 0"
    row_box = HBox(
        row_widgets,
        layout=Layout(width="100%", justify_content="flex-start"),
    )
    slider_rows.append(row_box)

legend_note = HTML("<i>V_mag bin edges update dynamically with current filters.</i>")

# Top-to-bottom order of the dashboard: title, note, filters, histogram
# controls, buttons with status line, and finally the figure output.
controls_column = [
    HTML("<h3>GAIA Sample – Dynamic V_mag Binned Stacked Histograms</h3>"),
    legend_note,
    filter_header,
]
controls_column.extend(slider_rows)
controls_column.extend([
    HTML("<b>Histogram Controls</b>"),
    HBox([auto_bins_checkbox, normalize_checkbox]),
    HBox([min_bins_slider, max_bins_slider, global_bins_slider]),
    HBox([update_button, reset_button, status_html]),
    out_plot,
])

ui = VBox(controls_column)
display(ui)
# Render the initial figure.
draw()


VBox(children=(HTML(value='<h3>GAIA Sample – Dynamic V_mag Binned Stacked Histograms</h3>'), HTML(value='<i>V_…