# In [2]:
import pandas as pd
import numpy as np
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display
import base64
from ipywidgets import (
    VBox, HBox, FloatRangeSlider, IntSlider, Checkbox,
    Button, Layout, HTML, Output, Dropdown
)

# Optional dependency: probe for a FileDownload widget and record the outcome in
# HAS_FILEDOWNLOAD so later code can pick between the widget and a fallback.
# NOTE(review): confirm that the installed ipywidgets release actually exports
# FileDownload; when it does not, the ImportError branch below is what runs.
try:
    from ipywidgets import FileDownload
    HAS_FILEDOWNLOAD = True
except ImportError:
    HAS_FILEDOWNLOAD = False

# ---------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------
# (workbook path, human-readable label) pairs offered in the dataset dropdown.
INPUT_FILES = [
    ("Gaia_homogeneous_target_selection_2025.08.22_10_granulation.xlsx", "+10% granulation RMS"),
    ("Gaia_homogeneous_target_selection_2025.08.22_50_granulation.xlsx", "+50% granulation RMS"),
    ("Gaia_homogeneous_target_selection_2025.08.22_100_granulation.xlsx", "+100% granulation RMS")
]
# Workbook loaded at startup and pre-selected in the dataset dropdown.
DEFAULT_FILE_PATH = "Gaia_homogeneous_target_selection_2025.08.22_100_granulation.xlsx"

# Columns that never become filter/plot candidates (compared after name normalisation).
EXPLICITLY_EXCLUDED_COLS = {"source_id_dr2", "source_id_dr3"}
# Columns that are not drawn as histograms and not offered as color variable;
# they are still placed first among the filter-slider candidates.
EXCLUDE_FROM_PLOT = {
    "HWO_match",
    "LOPS2_match",
    "TESS_confirmed_match",
    "TESS_candidate_match",
    "sum_score"
}

# Slider grid layout: sliders per HBox row, CSS width of each slider, and the
# right margin applied to every slider except the last one in its row.
SLIDERS_PER_ROW = 5
SLIDER_WIDTH = "26%"
SLIDER_COLUMN_GAP = "40px"

# The slider for this column starts (and resets) with its upper handle at
# min(HZ_DETECTION_LIMIT_CAP, column maximum) instead of the column maximum.
HZ_DETECTION_LIMIT_VAR = "HZ Detection Limit [M_Earth]"
HZ_DETECTION_LIMIT_CAP = 4.0

# Preferred color variables; both are starred in the "Color by" dropdown and
# DEFAULT_COLOR_VAR is the initial selection when the column exists.
DEFAULT_COLOR_VAR = "T_eff [K]"
SECOND_COLOR_VAR = "V_mag"

# The color variable is cut into N_COLOR_GROUPS equal-width groups; group i is
# drawn with COLOR_PALETTE[i % len(COLOR_PALETTE)].
N_COLOR_GROUPS = 5
COLOR_PALETTE = [
    "#e41a1c",
    "#ff7f00",
    "#4daf4a",
    "#377eb8",
    "#984ea3"
]
# Passed to the figure layout (barmode) and to every trace (opacity).
BARMODE = "stack"
BAR_OPACITY = 0.9

# Every filter slider uses the same step, regardless of the column's range.
CONSTANT_SLIDER_STEP = 0.01
SLIDER_READOUT_FORMAT = '.2f'   # Readout format spec (two decimals); e.g. '.3f' shows more
FORCE_CONTINUOUS_UPDATE = True  # Set False if performance becomes sluggish

# ---------------------------------------------------------------
# DATA LOADING
# ---------------------------------------------------------------
def load_data(path):
    """Read one target-selection workbook into a DataFrame.

    Identifier columns are read as strings, column headers are stripped of
    surrounding whitespace, and a "Distance [pc]" column is derived as
    1000 / Parallax when the workbook has a "Parallax" column but no distance.
    Non-finite derived distances are replaced by NaN.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not Path(path).exists():
        raise FileNotFoundError(path)

    id_columns = ('source_id', 'source_id_dr2', 'source_id_dr3', 'HIP Number')
    table = pd.read_excel(path, dtype={name: str for name in id_columns})
    table.columns = [header.strip() for header in table.columns]

    has_distance = "Distance [pc]" in table.columns
    has_parallax = "Parallax" in table.columns
    if not has_distance and has_parallax:
        # Zero / NaN parallaxes are expected; silence the warnings and blank
        # the resulting inf / NaN distances instead.
        with np.errstate(divide="ignore", invalid="ignore"):
            table["Distance [pc]"] = 1000.0 / table["Parallax"]
            not_finite = ~np.isfinite(table["Distance [pc]"])
            table.loc[not_finite, "Distance [pc]"] = np.nan
    return table

# Dropdown options are (label, value) pairs, so swap the (path, label) order
# used in INPUT_FILES: the label is displayed, the path is the widget value.
file_dropdown_options = [(label, path) for path, label in INPUT_FILES]
file_dropdown = Dropdown(
    options=file_dropdown_options,
    value=DEFAULT_FILE_PATH,
    description="Dataset:",
    style={'description_width': 'initial'},
    layout=Layout(width="300px")
)

# Module-level state: df_raw is the table as loaded, df is the working copy
# that filters and plots read. Both are rebound when another dataset is chosen.
df_raw = load_data(DEFAULT_FILE_PATH)
df = df_raw.copy()

# ---------------------------------------------------------------
# COLUMN SELECTION
# ---------------------------------------------------------------
def _norm_name(c):
    return c.lower().replace(" ", "").replace("_", "")

def is_source_id(col):
    """Return True when the normalised column name contains "sourceid".

    The normalised name never contains an underscore, so a separate check for
    "source_id" could never match and is not needed.
    """
    return "sourceid" in _norm_name(col)

def is_hip(col):
    """Return True when the normalised column name is a Hipparcos-id alias."""
    hip_aliases = ("hip", "hipnumber", "hipid")
    normalised = _norm_name(col)
    return normalised in hip_aliases

def is_explicit_excluded(col):
    """Return True when ``col`` matches an EXPLICITLY_EXCLUDED_COLS entry.

    Both sides are compared after name normalisation, so case, spaces and
    underscores do not matter.
    """
    for excluded in EXPLICITLY_EXCLUDED_COLS:
        if _norm_name(col) == _norm_name(excluded):
            return True
    return False

def get_eligible_cols(dataframe):
    """List the numeric columns of ``dataframe`` that are not identifiers.

    A column is dropped when it looks like a source id, a HIP id, or is named
    in the explicit exclusion set. Column order follows ``dataframe.columns``.
    """
    def _is_identifier(name):
        return is_source_id(name) or is_hip(name) or is_explicit_excluded(name)

    return [
        name
        for name in dataframe.columns
        if pd.api.types.is_numeric_dtype(dataframe[name]) and not _is_identifier(name)
    ]

def get_ordered_cols(dataframe, eligible):
    """Order ``eligible`` by position in ``dataframe.columns``.

    Eligible names that are not columns of ``dataframe`` are appended at the
    end, in the order they appear in ``eligible``.
    """
    in_frame_order = [name for name in dataframe.columns if name in eligible]
    leftovers = [name for name in eligible if name not in in_frame_order]
    return in_frame_order + leftovers

# Module-level column lists for the initially loaded dataset: the numeric,
# non-identifier columns, and the same columns in the workbook's column order.
eligible_cols = get_eligible_cols(df)
ordered_cols = get_ordered_cols(df_raw, eligible_cols)

# ---------------------------------------------------------------
# HELPERS
# ---------------------------------------------------------------
def is_discrete_like(series: pd.Series, discrete_unique_max=15):
    """Decide whether ``series`` should be drawn as discrete bars.

    True for boolean series, for series with no non-NaN values, for integer
    series with at most ``discrete_unique_max`` distinct values, and for any
    series with at most three distinct values. Only used to pick the histogram
    method; it does not influence slider stepping.
    """
    if series.dtype == bool:
        return True

    values = series.dropna()
    if values.empty:
        return True

    n_unique = values.nunique()
    few_integers = pd.api.types.is_integer_dtype(values) and n_unique <= discrete_unique_max
    return bool(few_integers or n_unique <= 3)

# ---------------------------------------------------------------
# BUILD SLIDERS (CONSTANT STEP)
# ---------------------------------------------------------------
def build_sliders(dataframe):
    """Create one FloatRangeSlider per filterable column of ``dataframe``.

    Candidates are the EXCLUDE_FROM_PLOT columns (in sorted order) followed by
    the module-level ``ordered_cols``. A candidate is skipped when it is not a
    column of ``dataframe``, has no non-NaN values, or its extrema cannot be
    converted to float.

    Every slider spans the column's [min, max] with CONSTANT_SLIDER_STEP. The
    slider for HZ_DETECTION_LIMIT_VAR starts with its upper handle capped at
    HZ_DETECTION_LIMIT_CAP; all others start at the full range.

    Returns:
        (sliders_dict, filter_vars_list): column name -> slider, and the column
        names in display order.
    """
    sliders_dict = {}
    filter_vars_list = []

    # EXCLUDE_FROM_PLOT is a set; sorting it keeps the slider order stable
    # across kernel sessions instead of following the set's iteration order.
    candidates = sorted(EXCLUDE_FROM_PLOT) + [
        c for c in ordered_cols if c not in EXCLUDE_FROM_PLOT
    ]

    for col in candidates:
        if col not in dataframe.columns or col in sliders_dict:
            continue
        s = dataframe[col].dropna()
        if s.empty:
            continue

        # The EXCLUDE_FROM_PLOT columns are not checked for a numeric dtype
        # anywhere, so skip a column that cannot provide float bounds rather
        # than aborting the whole UI build.
        try:
            vmin, vmax = float(s.min()), float(s.max())
        except (TypeError, ValueError):
            continue

        # Start at the full range, except for the capped HZ detection limit.
        initial = (vmin, vmax)
        if col == HZ_DETECTION_LIMIT_VAR:
            upper = min(HZ_DETECTION_LIMIT_CAP, vmax)
            if upper >= vmin:
                initial = (vmin, upper)

        # Same step for every column (integer-like columns included).
        sliders_dict[col] = FloatRangeSlider(
            description=col,
            min=vmin,
            max=vmax,
            value=initial,
            step=CONSTANT_SLIDER_STEP,
            layout=Layout(width=SLIDER_WIDTH),
            continuous_update=FORCE_CONTINUOUS_UPDATE,
            style={'description_width': 'initial'},
            readout_format=SLIDER_READOUT_FORMAT
        )
        filter_vars_list.append(col)

    return sliders_dict, filter_vars_list

# Module-level filter state: column name -> slider, plus the display order.
sliders, filter_vars = build_sliders(df)

# ---------------------------------------------------------------
# BINNING (unchanged logic)
# ---------------------------------------------------------------
def select_bins(series: pd.Series, min_bins: int, max_bins: int):
    """Choose histogram bin edges for ``series``.

    Args:
        series: values to bin; NaNs are dropped and the rest cast to float.
        min_bins: lower bound on the bin count (not applied on the
            small-sample path, which uses a fixed minimum of 3).
        max_bins: upper bound on the bin count.

    Returns:
        A 1-D array of bin edges, or None when no non-NaN values remain.
        For discrete-like data the edges are generally not equally spaced;
        on every other path they are equally spaced.
    """
    s = series.dropna().astype(float)
    n = s.shape[0]
    if n == 0:
        return None
    if is_discrete_like(s):
        # One bin per distinct value, with edges halfway between neighbours.
        unique_vals = np.sort(s.unique())
        if unique_vals.size == 1:
            return np.array([unique_vals[0] - 0.5, unique_vals[0] + 0.5])
        diffs = np.diff(unique_vals) / 2
        # Outer edges extend by half of the first / last gap.
        return np.concatenate([
            [unique_vals[0] - diffs[0]],
            unique_vals[:-1] + diffs,
            [unique_vals[-1] + diffs[-1]]
        ])
    s_min, s_max = s.min(), s.max()
    rng = s_max - s_min
    if rng == 0:
        # All values identical: a single unit-width bin centred on the value.
        return np.array([s_min - 0.5, s_min + 0.5])
    q25, q75 = np.percentile(s, [25, 75])
    iqr = q75 - q25
    std = s.std(ddof=1) if n > 1 else 0.0
    # Candidate bin widths from up to three rules; their median is used below.
    widths = []
    if iqr > 0:
        # Freedman-Diaconis style width: 2 * IQR / n^(1/3).
        h_fd = 2 * iqr / (n ** (1/3))
        if h_fd > 0:
            widths.append(h_fd)
    if std > 0:
        # Scott style width: 3.5 * std / n^(1/3).
        h_scott = 3.5 * std / (n ** (1/3))
        if h_scott > 0:
            widths.append(h_scott)
    # Sturges style bin count (log2(n) + 1), converted to a width.
    k_sturges = np.ceil(np.log2(n) + 1) if n > 1 else 1
    if rng > 0 and k_sturges > 0:
        widths.append(rng / k_sturges)
    if n < 30:
        # Small samples ignore the candidate widths: use the larger of sqrt(n)
        # and the number of distinct values, clipped to [3, max_bins].
        approx_bins_small = max(int(np.sqrt(n)), len(np.unique(s)))
        approx_bins_small = int(np.clip(approx_bins_small, 3, max_bins))
        return np.linspace(s_min, s_max, approx_bins_small + 1)
    if not widths:
        # No rule produced a width: fall back to sqrt(n) bins.
        k = int(np.sqrt(n))
        k = int(np.clip(k, min_bins, max_bins))
        return np.linspace(s_min, s_max, k + 1)
    h = np.median(widths)
    if h <= 0 or not np.isfinite(h):
        k = int(np.sqrt(n))
    else:
        k = int(np.round((s_max - s_min) / h))
    k = int(np.clip(k, min_bins, max_bins))
    # Guard against a non-positive count if min_bins was passed as < 1.
    k = max(k, 1)
    return np.linspace(s_min, s_max, k + 1)

# ---------------------------------------------------------------
# FILTER APPLICATION
# ---------------------------------------------------------------
def current_filter_ranges():
    """Snapshot the current value of every filter slider, keyed by column."""
    ranges = {}
    for name, widget in sliders.items():
        ranges[name] = widget.value
    return ranges

def apply_filters(df_in: pd.DataFrame):
    """Return a copy of the rows of ``df_in`` that pass every active slider.

    A slider left at its full [min, max] range is inactive and keeps all rows,
    including rows where the column is NaN. An active slider keeps only rows
    whose value lies inside the selected range (bounds inclusive), which also
    drops NaN rows for that column.
    """
    keep = pd.Series(True, index=df_in.index)
    for name, selected in current_filter_ranges().items():
        if name not in df_in.columns:
            continue
        low, high = selected
        widget = sliders[name]
        if (low, high) == (widget.min, widget.max):
            continue
        above = df_in[name] >= low
        below = df_in[name] <= high
        keep &= above & below
    return df_in[keep].copy()

# ---------------------------------------------------------------
# BROWSER DOWNLOAD FUNCTION
# ---------------------------------------------------------------
def download_csv_in_browser(df_to_download, filename):
    """Wrap ``df_to_download`` as a CSV download link.

    The frame is serialised without its index, base64-encoded, and embedded in
    a ``data:`` URI. Returns an HTML widget holding the anchor element whose
    ``download`` attribute is ``filename``.
    """
    csv_text = df_to_download.to_csv(index=False)
    payload = base64.b64encode(csv_text.encode()).decode()
    anchor = (
        f'<a download="{filename}" href="data:text/csv;base64,{payload}" '
        'target="_blank">Download CSV File</a>'
    )
    return HTML(anchor)

# ---------------------------------------------------------------
# WIDGETS
# ---------------------------------------------------------------
def update_color_dropdown(eligible_cols_list):
    """Build the "Color by" dropdown for the currently loaded dataset.

    Options are the eligible columns that are not in EXCLUDE_FROM_PLOT.
    DEFAULT_COLOR_VAR and SECOND_COLOR_VAR are inserted at positions 0 and 1
    when they exist in the module-level ``df`` but are missing from the
    options, and both are labelled with a star.

    The initial selection is DEFAULT_COLOR_VAR when available, otherwise the
    first option. With no options at all the dropdown is created empty
    (value None) instead of raising IndexError.
    """
    color_var_options = [col for col in eligible_cols_list if col not in EXCLUDE_FROM_PLOT]
    if DEFAULT_COLOR_VAR not in color_var_options and DEFAULT_COLOR_VAR in df.columns:
        color_var_options.insert(0, DEFAULT_COLOR_VAR)
    if SECOND_COLOR_VAR not in color_var_options and SECOND_COLOR_VAR in df.columns:
        color_var_options.insert(1, SECOND_COLOR_VAR)

    starred = (DEFAULT_COLOR_VAR, SECOND_COLOR_VAR)
    dropdown_options = [
        (f"★ {col} (default)", col) if col in starred else (col, col)
        for col in color_var_options
    ]

    if DEFAULT_COLOR_VAR in df.columns and DEFAULT_COLOR_VAR in color_var_options:
        initial_value = DEFAULT_COLOR_VAR
    elif dropdown_options:
        initial_value = dropdown_options[0][1]
    else:
        # Nothing to color by (e.g. a dataset without numeric columns).
        initial_value = None

    return Dropdown(
        options=dropdown_options,
        value=initial_value,
        description="Color by:",
        style={'description_width': 'initial'},
        layout=Layout(width="300px")
    )

# Module-level widgets shared by draw(), the callbacks and the layout code.
color_var_dropdown = update_color_dropdown(eligible_cols)
# Binning controls: draw() uses min/max bins when "auto" is checked and the
# global bin count otherwise.
auto_bins_checkbox = Checkbox(value=True, description="Auto bins per variable", indent=False)
global_bins_slider = IntSlider(value=40, min=5, max=150, step=1, description="Global bins", layout=Layout(width="225px"))
min_bins_slider = IntSlider(value=10, min=2, max=50, step=1, description="Min bins", layout=Layout(width="205px"))
max_bins_slider = IntSlider(value=70, min=20, max=200, step=1, description="Max bins", layout=Layout(width="215px"))
reset_button = Button(description="Reset", button_style="warning")
export_button = Button(description="Export Filtered CSV", button_style="success", icon="save")
# status_html shows the filtered row count; out_plot holds the figure;
# filtered_out holds the export link and preview; file_loading_status reports
# dataset switches.
status_html = HTML("")
out_plot = Output()
filtered_out = Output()
file_loading_status = HTML("")

# NOTE(review): download_button is created here but is not added to any layout
# in the visible part of this file; confirm whether it is still needed.
if HAS_FILEDOWNLOAD:
    # Called lazily by the widget so the CSV reflects the filters at click time.
    def _download_data():
        return apply_filters(df).to_csv(index=False).encode("utf-8")
    download_button = FileDownload(
        data=_download_data,
        filename="filtered_subset.csv",
        description="Download current subset",
        button_style="info"
    )
else:
    download_button = HTML("<i>FileDownload widget not available</i>")

# ---------------------------------------------------------------
# DRAW
# ---------------------------------------------------------------
def format_bin_label(edges, i):
    """Format bin ``i`` of ``edges`` as "lower–upper" with two decimals."""
    lower, upper = edges[i], edges[i + 1]
    return "{:.2f}–{:.2f}".format(lower, upper)

def _color_group_edges(values):
    """Return the equal-width bin edges used to split the color variable."""
    lo, hi = values.min(), values.max()
    if np.isclose(lo, hi):
        # Degenerate range: a single unit-width group around the value.
        return np.array([lo - 0.5, lo + 0.5])
    return np.linspace(lo, hi, N_COLOR_GROUPS + 1)


def _trace_style(color, line_width, name, show_legend, var):
    """Return the styling kwargs shared by every bar / histogram trace."""
    return dict(
        marker=dict(
            color=color,
            line=dict(width=line_width, color="rgba(0,0,0,0.5)"),
            opacity=BAR_OPACITY
        ),
        opacity=BAR_OPACITY,
        name=name,
        legendgroup="colorvar",
        showlegend=show_legend,
        hovertemplate=f"{var}<br>%{{x}} : %{{y}}<br>{name}<extra></extra>"
    )


def draw():
    """Redraw the grid of stacked histograms inside ``out_plot``.

    Reads the module-level ``df``, the filter sliders (through
    ``apply_filters``), the color dropdown and the binning controls. One
    subplot is drawn per column of ``ordered_cols`` that is not in
    EXCLUDE_FROM_PLOT; within each subplot the rows are split into groups by
    equal-width bins of the selected color variable and stacked.

    Discrete-like columns are drawn as bars, one per distinct value. Other
    columns are drawn as histograms, binned either by ``select_bins`` ("auto")
    or with the global bin count. Also updates ``status_html`` with the
    filtered row count. Prints a short message instead of a figure when there
    is nothing to plot.
    """
    with out_plot:
        out_plot.clear_output(wait=True)

        filtered = apply_filters(df)
        status_html.value = f"<b>Filtered rows:</b> {len(filtered)} / {len(df)}"
        if filtered.empty:
            print("No data after filtering.")
            return

        color_var = color_var_dropdown.value
        if color_var not in filtered.columns:
            print(f"Selected color variable '{color_var}' not found in data.")
            return

        color_filtered = filtered[color_var].dropna()
        if color_filtered.empty:
            print(f"No {color_var} values after filtering.")
            return

        # Rows with a NaN color value get a NaN group and are never plotted.
        edges = _color_group_edges(color_filtered)
        color_bins = pd.cut(filtered[color_var], bins=edges, include_lowest=True, labels=False)
        n_bins_in_use = len(edges) - 1

        plot_vars = [v for v in ordered_cols if v not in EXCLUDE_FROM_PLOT]
        if not plot_vars:
            print("No variables to plot (all excluded).")
            return

        # Roughly square grid, limited to between 3 and 6 columns.
        n_vars = len(plot_vars)
        base_cols = int(np.ceil(np.sqrt(n_vars)))
        n_cols = max(3, min(6, base_cols))
        n_rows = int(np.ceil(n_vars / n_cols))

        fig = make_subplots(
            rows=n_rows, cols=n_cols,
            horizontal_spacing=0.04,
            vertical_spacing=0.07
        )

        use_auto = auto_bins_checkbox.value
        min_bins = min_bins_slider.value
        max_bins = max_bins_slider.value
        global_bins = global_bins_slider.value

        # Color groups that already own a legend entry. Tracking this per
        # group (rather than "first variable only") guarantees an entry even
        # when a group has no data in the first plotted variable.
        legend_shown = set()

        for idx, var in enumerate(plot_vars):
            # Every variable keeps its own grid cell, even when left empty.
            row, col = idx // n_cols + 1, idx % n_cols + 1

            s_all = filtered[var].dropna()
            if s_all.empty:
                continue

            discrete = is_discrete_like(s_all)
            if discrete:
                categories = np.sort(s_all.unique())
                binning = None
            elif use_auto:
                edges_var = select_bins(s_all, min_bins=min_bins, max_bins=max_bins)
                if edges_var is None or len(edges_var) < 2:
                    edges_var = np.array([s_all.min() - 0.5, s_all.min() + 0.5])
                # For non-discrete data select_bins returns equally spaced
                # edges, so the first gap is the bin size for the whole axis.
                binning = dict(
                    xbins=dict(
                        start=edges_var[0],
                        end=edges_var[-1],
                        size=edges_var[1] - edges_var[0]
                    )
                )
            else:
                binning = dict(nbinsx=global_bins)

            for bin_idx in range(n_bins_in_use):
                s_bin = filtered.loc[color_bins == bin_idx, var].dropna()
                if s_bin.empty:
                    continue

                color = COLOR_PALETTE[bin_idx % len(COLOR_PALETTE)]
                name = f"{color_var} {format_bin_label(edges, bin_idx)}"
                show_legend = bin_idx not in legend_shown
                legend_shown.add(bin_idx)

                if discrete:
                    counts = s_bin.value_counts()
                    trace = go.Bar(
                        x=categories.astype(float),
                        y=[counts.get(cat, 0) for cat in categories],
                        **_trace_style(color, 0.3, name, show_legend, var)
                    )
                else:
                    trace = go.Histogram(
                        x=s_bin,
                        **binning,
                        **_trace_style(color, 0.25, name, show_legend, var)
                    )
                fig.add_trace(trace, row=row, col=col)

            fig.update_xaxes(title_text=var, row=row, col=col)

        height = 250 * n_rows + 80
        selected_file_label = next((label for label, path in file_dropdown_options if path == file_dropdown.value), "Unknown dataset")

        fig.update_layout(
            title=f"{selected_file_label} - Histograms colored by {color_var} (Vars={n_vars}; Excluded={len(EXCLUDE_FROM_PLOT)})",
            bargap=0.03,
            height=height,
            template="plotly_white",
            margin=dict(l=60, r=25, t=70, b=60),
            barmode=BARMODE
        )
        fig.update_yaxes(rangemode="tozero")
        fig.show()

# ---------------------------------------------------------------
# CALLBACKS
# ---------------------------------------------------------------
def on_reset(_):
    """Reset button handler: restore default filters, binning and color.

    Every slider goes back to its full range, except the HZ detection limit
    slider, whose upper handle returns to the configured cap. Each slider is
    assigned exactly once, so a reset does not fire an extra value-change
    (and redraw) for the capped slider.
    """
    for col, sl in sliders.items():
        target = (sl.min, sl.max)
        if col == HZ_DETECTION_LIMIT_VAR:
            upper = min(HZ_DETECTION_LIMIT_CAP, sl.max)
            if upper >= sl.min:
                target = (sl.min, upper)
        sl.value = target

    auto_bins_checkbox.value = True
    global_bins_slider.value = 40
    min_bins_slider.value = 10
    max_bins_slider.value = 70

    # Only switch the color variable when the default is a valid option.
    default_color = DEFAULT_COLOR_VAR if DEFAULT_COLOR_VAR in df.columns else None
    if default_color and any(opt[1] == default_color for opt in color_var_dropdown.options):
        color_var_dropdown.value = default_color
    draw()

def on_export(_):
    """Export button handler: offer the filtered rows as a CSV download.

    Writes the row count, a download link (file name stamped with the current
    UTC time) and a 20-row preview into ``filtered_out``.
    """
    filtered = apply_filters(df)
    # Timestamp.utcnow() is deprecated in recent pandas; now(tz="UTC") yields
    # the same UTC wall-clock time for the file name.
    ts = pd.Timestamp.now(tz="UTC").strftime("%Y%m%d_%H%M%S")
    filename = f"filtered_subset_{ts}.csv"
    download_link = download_csv_in_browser(filtered, filename)
    with filtered_out:
        filtered_out.clear_output()
        print(f"Ready to export {len(filtered)} rows")
        display(download_link)
        display(filtered.head(20))

def on_file_change(change):
    """Dataset dropdown handler: load the chosen workbook and rebuild the UI.

    Ignores every notification except a change of the ``value`` trait. On
    success the module-level data frames, column lists, sliders and color
    dropdown are replaced, the layout is rebuilt and the plot redrawn. Any
    exception raised along the way is reported in ``file_loading_status``.
    """
    global df, df_raw, sliders, filter_vars, eligible_cols, ordered_cols, color_var_dropdown

    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if not is_value_change:
        return

    file_loading_status.value = "<b>Loading new dataset...</b>"
    try:
        selected_path = change['new']
        df_raw = load_data(selected_path)
        df = df_raw.copy()

        # Everything derived from the data has to be rebuilt for the new file.
        eligible_cols = get_eligible_cols(df)
        ordered_cols = get_ordered_cols(df_raw, eligible_cols)
        sliders, filter_vars = build_sliders(df)
        color_var_dropdown = update_color_dropdown(eligible_cols)

        # rebuild_ui() also attaches the redraw observers to the new widgets.
        rebuild_ui()
        file_loading_status.value = "<b>Dataset loaded successfully!</b>"
        draw()
    except Exception as err:
        file_loading_status.value = f"<b style='color:red'>Error loading file: {str(err)}</b>"


# Wire the buttons and the dataset dropdown to their handlers.
reset_button.on_click(on_reset)
export_button.on_click(on_export)
# Observed without a names= filter: on_file_change receives every trait
# notification and itself ignores everything but 'value' changes.
file_dropdown.observe(on_file_change)

# ---------------------------------------------------------------
# LAYOUT
# ---------------------------------------------------------------
def rebuild_ui():
    """Rebuild the control layout around the current sliders and dropdown.

    Attaches a redraw observer to every slider and to the color dropdown,
    lays the sliders out in rows of SLIDERS_PER_ROW, and replaces the children
    of the existing ``ui`` container. Expects the module-level ``ui`` to exist.
    """
    global ui, slider_widgets, slider_rows, filter_header, controls_column

    def _observe_redraw(widget):
        # Remove only the redraw observer attached by an earlier call.
        # unobserve_all() must not be used here: it also discards the
        # widget's own internal trait handlers (a Dropdown relies on one to
        # turn the index sent by the browser into its value), which would
        # leave the control unresponsive.
        previous = getattr(widget, "_redraw_observer", None)
        if previous is not None:
            widget.unobserve(previous, names="value")
        handler = create_observer(widget)
        widget.observe(handler, names="value")
        widget._redraw_observer = handler

    filter_header = HTML(f"<b>Filters (n={len(sliders)})</b>")
    slider_widgets = [sliders[v] for v in filter_vars]

    for slider in sliders.values():
        _observe_redraw(slider)
    _observe_redraw(color_var_dropdown)

    slider_rows = []
    for i in range(0, len(slider_widgets), SLIDERS_PER_ROW):
        row_widgets = slider_widgets[i:i+SLIDERS_PER_ROW]
        # Gap between neighbours: right margin on all but the last in a row.
        for w in row_widgets[:-1]:
            if hasattr(w, "layout"):
                w.layout.margin = f"0 {SLIDER_COLUMN_GAP} 0 0"
        slider_rows.append(HBox(row_widgets, layout=Layout(width="100%", justify_content="flex-start")))

    # Hint next to the export button: the download link is rendered in the
    # export/preview output at the bottom of the layout.
    export_controls = HBox(
        [export_button, HTML("<span style='margin-left:10px; color: #888;'>scroll to bottom to download</span>")],
        layout=Layout(justify_content="flex-start", column_gap="15px")
    )

    controls_column = [
        HTML("<h3>GAIA Sample – Dynamic Binned Stacked Histograms</h3>"),
        HBox([file_dropdown, file_loading_status], layout=Layout(width="100%")),
        HBox([color_var_dropdown]),
        filter_header,
        *slider_rows,
        HTML("<b>Histogram Controls</b>"),
        HBox([auto_bins_checkbox]),
        HBox([min_bins_slider, max_bins_slider, global_bins_slider]),
        HBox([reset_button, status_html]),
        export_controls,
        out_plot,
        HTML("<b>Export / Preview</b>"),
        filtered_out
    ]
    ui.children = controls_column

def create_observer(widget):
    """Return a change handler that redraws the plot.

    ``widget`` is accepted for call-site symmetry but is not used; the
    returned handler ignores the change payload and calls ``draw()``.
    """
    def _redraw(change):
        return draw()

    return _redraw

# Redraw whenever a binning control, the color dropdown or a filter slider
# changes value.
auto_bins_checkbox.observe(create_observer(auto_bins_checkbox), names="value")
global_bins_slider.observe(create_observer(global_bins_slider), names="value")
min_bins_slider.observe(create_observer(min_bins_slider), names="value")
max_bins_slider.observe(create_observer(max_bins_slider), names="value")
color_var_dropdown.observe(create_observer(color_var_dropdown), names="value")

for var, slider in sliders.items():
    slider.observe(create_observer(slider), names="value")

# Initial layout. This mirrors the layout built in rebuild_ui(), which takes
# over after a dataset switch; keep the two in sync when changing either.
filter_header = HTML(f"<b>Filters (n={len(sliders)})</b>")
slider_widgets = [sliders[v] for v in filter_vars]
slider_rows = []
for i in range(0, len(slider_widgets), SLIDERS_PER_ROW):
    row_widgets = slider_widgets[i:i+SLIDERS_PER_ROW]
    # Gap between neighbours: right margin on all but the last slider in a row.
    for j, w in enumerate(row_widgets):
        if hasattr(w, "layout") and j < len(row_widgets) - 1:
            w.layout.margin = f"0 {SLIDER_COLUMN_GAP} 0 0"
    slider_rows.append(HBox(row_widgets, layout=Layout(width="100%", justify_content="flex-start")))

# Add the "scroll to bottom to download" text next to the Export Filtered CSV button
export_controls = HBox(
    [export_button, HTML("<span style='margin-left:10px; color: #888;'>scroll to bottom to download</span>")],
    layout=Layout(justify_content="flex-start", column_gap="15px")
)

# Top-to-bottom order of the whole interface.
controls_column = [
    HTML("<h3>GAIA Sample – Dynamic Binned Stacked Histograms</h3>"),
    HBox([file_dropdown, file_loading_status], layout=Layout(width="100%")),
    HBox([color_var_dropdown]),
    filter_header,
    *slider_rows,
    HTML("<b>Histogram Controls</b>"),
    HBox([auto_bins_checkbox]),
    HBox([min_bins_slider, max_bins_slider, global_bins_slider]),
    HBox([reset_button, status_html]),
    export_controls,
    out_plot,
    HTML("<b>Export / Preview</b>"),
    filtered_out
]

# ui is the single container that rebuild_ui() later refills in place.
ui = VBox(controls_column)
display(ui)
draw()

# Cell output (truncated widget repr):
# VBox(children=(HTML(value='<h3>GAIA Sample – Dynamic Binned Stacked Histograms</h3>'), HBox(children=(Dropdown…