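# This notebook creates an interactive ipywidgets dashboard of the axion–photon coupling versus axion mass, overlaying model lines/bands and bound curves on one shared figure.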

In [None]:
# --- Panel interactivo con ipywidgets (inline con refresco en vivo) ----------
%matplotlib ipympl
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

from PlotFuncs import FigSetup, AxionPhoton, MySaveFig, BlackHoleSpins, FilledLimit
# Global Matplotlib defaults for this notebook: external LaTeX rendering off,
# serif font family with Palatino as the serif face, constrained layout off.
mpl.rcParams.update({
    'text.usetex': False,
    'font.family': 'serif',
    'font.serif': ['Palatino'],
    'figure.constrained_layout.use': False,
})
# mpl.rcParams['savefig.bbox'] = 'standard'   # (optional default for all saves)

from PlotFuncs import FilledLimit, AxionPhoton, BlackHoleSpins
# Física
# Fine-structure constant.
alpha = 1 / 137.035999084
# Scale that the mass (in eV) is divided by inside g_agamma
# (presumably m_a * f_a in eV*GeV -- TODO confirm against the reference used).
K = 5.70e6
# Common prefactor alpha / (2*pi) of the coupling.
pref = alpha / (2 * np.pi)


def g_agamma(m_eV, C):
    """Return ``|pref * C * (m_eV / K)|``.

    ``m_eV`` may be a scalar or a NumPy array (the result has the same shape);
    ``C`` is the model coefficient. The absolute value makes the sign of ``C``
    irrelevant.
    """
    scaled_mass = m_eV / K
    return np.abs(pref * C * scaled_mass)

# Benchmark models shown in the dashboard. Every entry is a dict with:
#   "name": legend text,
#   "Ndw":  text printed in the legend after N_dw,
#   "C":    (min, max) pair of coefficients; equal values give a single line,
#           different values give a band.
models = [
    {"name": label, "Ndw": ndw_text, "C": c_range}
    for label, ndw_text, c_range in (
        ("KSVZ", "1", (-1.92, -1.92)),
        ("DFSZ-I", "6,3", (0.75, 0.75)),
        ("DFSZ-II", "6,3", (-1.25, -1.25)),
        ("Astrophobic QCD axion", "1,2", (-6.59, 0.74)),
        (r"VISH$\nu$", "1", (0.75, 0.75)),
        (r"$\nu$DFSZ", "6", (0.75, 0.75)),
        ("Majoraxion", "—", (2.66, 2.66)),
        ("Composite Axion", "0/2/6", (1.33, 2.66)),
    )
]

def _bound(name, fn, **extra):
    """Build one bound entry: checkbox text, plotting callable, optional extra keys."""
    return {"name": name, "fn": fn, **extra}


# Bound overlays grouped by dashboard tab (dict key = tab title).
# Each entry holds "name" (checkbox text) and "fn" (callable that receives the
# axes); an optional "visible" key gives the initial checkbox state.
categories = {
    "Astrophysical Bounds": [
        _bound("Helioscopes", AxionPhoton.Helioscopes, visible=True),
        _bound("White Dwarfs", AxionPhoton.WhiteDwarfs),
        _bound("Stellar Bounds", AxionPhoton.StellarBounds),
    ],
    "Experimental Bounds": [
        _bound("Haloscopes", AxionPhoton.Haloscopes),
        _bound("Solar Basin", AxionPhoton.SolarBasin),
        _bound("StAB", AxionPhoton.StAB),
    ],
    "Test QCD": [
        _bound("QCD Axion", AxionPhoton.QCDAxion),
    ],
}

from IPython.display import display, clear_output
from PlotFuncs import FigSetup, MySaveFig
from PlotFuncs import AxionPhoton  # your module with all the plotting functions

# --- Interactive Axion Dashboard (shared fig & ax, inline) ---
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from PlotFuncs import FigSetup, MySaveFig, AxionPhoton

def html_label(html, width="160px"):
    """Return an ``ipywidgets.HTML`` widget showing *html*, laid out at *width*."""
    return widgets.HTML(value=html, layout=widgets.Layout(width=width))

# Slider captions, built as HTML label widgets (mass limits, then coupling limits).
mmin_label, mmax_label, ymin_label, ymax_label = (
    html_label(markup)
    for markup in (
        "m<sub>a</sub><sup>min</sup> [eV]",
        "m<sub>a</sub><sup>max</sup> [eV]",
        "g<sub>aγ</sub><sup>min</sup>",
        "g<sub>aγ</sub><sup>max</sup>",
    )
)


# === Shared figure and axes handed to the dashboard ===
fig, ax = FigSetup(Shape='Rectangular', ylab=r'$|g_{a\gamma}|$ [GeV$^{-1}$]', mathpazo=True)
# Close it in pyplot so the notebook does not also render a static copy.
plt.close(fig)


def interactive_axion_dashboard(models, categories=None, fig=None, ax=None,
                                title="Axion–Photon Coupling vs Mass — interactive"):
    """
    Display an ipywidgets dashboard that draws models and bounds on one shared axes.

    The left side is a tab widget (one tab of model checkboxes plus one tab per
    category); the right side holds four log sliders for the axis limits and a
    save button. Any change re-draws everything on ``ax`` and re-displays
    ``fig`` inside an Output widget.

    Parameters
    ----------
    models : list of dict
        Each entry provides "name" (legend text), "C" (a ``(cmin, cmax)`` pair;
        equal values draw a line, different values draw a band with its
        geometric-mean line) and optionally "Ndw" (text shown in the legend).
    categories : dict, optional
        Maps a tab title to a list of dicts with "name" (checkbox text), "fn"
        (callable invoked with the axes), and optionally "visible" (initial
        checkbox state, default False) and "kwargs" (extra keyword arguments
        passed to "fn").
    fig, ax :
        The shared Matplotlib figure and axes. Both are required.
    title : str
        Axes title.

    Raises
    ------
    ValueError
        If ``fig`` or ``ax`` is None.

    Notes
    -----
    Relies on the module-level label widgets ``mmin_label``, ``mmax_label``,
    ``ymin_label`` and ``ymax_label`` for the slider captions. Returns None.
    """
    # Function-scope imports keep this block independent of the cell's import order.
    import traceback
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    if fig is None or ax is None:
        raise ValueError("You must pass the shared fig and ax (created with FigSetup).")

    # === Physics ===
    alpha = 1/137.035999084
    K = 5.70e6
    pref = alpha/(2*np.pi)

    def g_agamma(m_eV, C):
        return np.abs(pref * C * (m_eV / K))

    # === Widgets ===
    def _log_slider(value, lo_exp, hi_exp):
        """Create one base-10 log slider spanning 10**lo_exp .. 10**hi_exp."""
        slider = widgets.FloatLogSlider(value=value, base=10, min=lo_exp, max=hi_exp,
                                        step=0.1, readout_format=".1e")
        slider.layout.width = "260px"
        return slider

    def _checkbox_panel(checks):
        """Wrap checkboxes in a scrollable box with select-all / select-none buttons."""
        select_all = widgets.Button(description='Seleccionar todo')
        select_none = widgets.Button(description='Deseleccionar todo')
        panel = widgets.VBox([
            widgets.VBox(checks, layout=widgets.Layout(min_width='260px', max_height='350px', overflow='auto')),
            widgets.HBox([select_all, select_none]),
        ])
        return panel, select_all, select_none

    mmin = _log_slider(1e-12, -15, 8)
    mmax = _log_slider(1e-1, -15, 8)
    ymin = _log_slider(1e-20, -30, -5)
    ymax = _log_slider(1e-8, -30, -5)

    # --- Tab 0: models ---
    model_checks = [widgets.Checkbox(value=True, description=m["name"]) for m in models]
    modelos_panel, sel_all_models, sel_none_models = _checkbox_panel(model_checks)

    tabs_children = [modelos_panel]
    tab_titles = ['Modelos']

    # --- One tab per category, each with its own buttons ---
    # {tab_name: {"checks": [...], "items": [...], "sel_all": btn, "sel_none": btn}}
    cat_checkgroups = {}
    if categories:
        for tab_name, items in categories.items():
            checks = [widgets.Checkbox(value=bool(it.get("visible", False)), description=it["name"])
                      for it in items]
            panel, sel_all_btn, sel_none_btn = _checkbox_panel(checks)
            tabs_children.append(panel)
            tab_titles.append(tab_name)
            cat_checkgroups[tab_name] = {"checks": checks, "items": items,
                                         "sel_all": sel_all_btn, "sel_none": sel_none_btn}

    tabs = widgets.Tab(children=tabs_children)
    for i, t in enumerate(tab_titles):
        tabs.set_title(i, t)

    # --- Right panel: sliders with the save button at the bottom ---
    save_btn = widgets.Button(description='Guardar figura', button_style='success')
    right = widgets.VBox([
        widgets.HBox([mmin_label, mmin]),
        widgets.HBox([mmax_label, mmax]),
        widgets.HBox([ymin_label, ymin]),
        widgets.HBox([ymax_label, ymax]),
        save_btn,
    ])
    display(widgets.HBox([tabs, right]))

    # --- Output area that holds the rendered figure (and any error text) ---
    out = widgets.Output()
    display(out)

    def _draw_bound(name, fn, kwargs, errors):
        """Call one bound function on ``ax``; keep limits, record errors, close stray figures."""
        # Snapshot BEFORE plt.sca: sca may re-register the shared figure with
        # pyplot, and the cleanup below must then close it again.
        prev_figs = set(plt.get_fignums())

        # Freeze current view & autoscale state so the call cannot move the limits.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        was_auto = ax.get_autoscale_on()
        ax.set_autoscale_on(False)
        plt.sca(ax)  # functions that use gca() internally then hit our axes

        try:
            try:
                fn(ax, **(kwargs or {}))      # prefer positional ax
            except TypeError:
                fn(ax=ax, **(kwargs or {}))   # fallback to keyword ax
        except Exception as e:
            # UI boundary: one failing bound must not break the whole redraw.
            errors.append(f"[ERROR] {name} raised: {e}\n{traceback.format_exc()}")
        finally:
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            ax.set_autoscale_on(was_auto)
            # Close any figures created (or re-registered) by the call.
            for num in (set(plt.get_fignums()) - prev_figs):
                plt.close(num)

    def redraw(*_):
        """Clear the shared axes and draw all selected models and bounds."""
        errors = []

        # -- reset canvas --
        ax.cla()
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r"$m_a$ [eV]")
        ax.set_ylabel(r"$|g_{a\gamma}|$ [GeV$^{-1}$]")
        # Sort so that dragging a "min" slider past its "max" does not invert the axis.
        xlims = tuple(sorted((mmin.value, mmax.value)))
        ylims = tuple(sorted((ymin.value, ymax.value)))
        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)

        # -- model lines / bands --
        m_grid = np.logspace(np.log10(xlims[0]), np.log10(xlims[1]), 600)
        for chk, md in zip(model_checks, models):
            if not chk.value:
                continue
            cmin, cmax = md["C"]
            ndw = md.get("Ndw", "—")
            label = rf"{md['name']} ($N_{{\rm dw}}$={ndw})"
            if np.isclose(cmin, cmax):
                ax.plot(m_grid, g_agamma(m_grid, cmin), lw=2, alpha=0.95, label=label)
            else:
                y1 = g_agamma(m_grid, cmin)
                y2 = g_agamma(m_grid, cmax)
                ylo, yhi = np.minimum(y1, y2), np.maximum(y1, y2)
                ax.fill_between(m_grid, ylo, yhi, alpha=0.25)
                # Geometric mean of the band edges carries the legend entry.
                ax.plot(m_grid, np.sqrt(ylo*yhi), lw=1.5, alpha=0.85, label=label)

        # -- bounds (per tab) --
        for grp in cat_checkgroups.values():
            for chk, it in zip(grp["checks"], grp["items"]):
                if chk.value:
                    _draw_bound(it["name"], it["fn"], it.get("kwargs", {}), errors)

        # Restore limits in case any bound function changed them.
        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)

        # --- prevent external labels from shrinking the axes ---
        fig.set_constrained_layout(False)
        for t in ax.texts:
            if hasattr(t, "set_in_layout"):
                t.set_in_layout(False)  # exclude from tight/constrained layout
            t.set_clip_on(True)

        # -- cosmetics + render --
        ax.grid(True, which='both', ls=':', alpha=0.4)
        # Skip the legend when nothing is labelled (avoids Matplotlib's warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='lower right', fontsize=9, frameon=False)
        ax.set_title(title, pad=8)

        with out:
            clear_output(wait=True)
            display(fig)
            # Show failures under the figure; output printed from a widget
            # callback outside an Output widget is easy to miss.
            for msg in errors:
                print(msg)

    # --- Event handlers ---
    def on_save(_):
        """Save the shared figure as PDF and PNG with the full (non-tight) bounding box."""
        try:
            # 'standard' is a value of the savefig.bbox rcParam, not a valid
            # bbox_inches argument, so select it through rc_context instead.
            with mpl.rc_context({'savefig.bbox': 'standard'}):
                fig.savefig("AxionPhoton_Dashboard_Shared.pdf")
                fig.savefig("AxionPhoton_Dashboard_Shared.png", dpi=300)
        except Exception as e:
            # UI boundary: report inside the dashboard instead of raising in the callback.
            with out:
                print("Save failed:", e)

    def _set_all(checks, value):
        """Set every checkbox to ``value`` and redraw once instead of once per checkbox."""
        for c in checks:
            c.unobserve(redraw, names='value')
        try:
            for c in checks:
                c.value = value
        finally:
            for c in checks:
                c.observe(redraw, names='value')
        redraw()

    # Redraw on changes
    for w in (mmin, mmax, ymin, ymax):
        w.observe(redraw, names='value')
    for chk in model_checks:
        chk.observe(redraw, names='value')
    for grp in cat_checkgroups.values():
        for chk in grp["checks"]:
            chk.observe(redraw, names='value')

    save_btn.on_click(on_save)

    # Models: select / deselect all
    sel_all_models.on_click(lambda _: _set_all(model_checks, True))
    sel_none_models.on_click(lambda _: _set_all(model_checks, False))

    # Categories: select / deselect all, per tab (g=grp binds the current group)
    for grp in cat_checkgroups.values():
        grp["sel_all"].on_click(lambda _, g=grp: _set_all(g["checks"], True))
        grp["sel_none"].on_click(lambda _, g=grp: _set_all(g["checks"], False))

    redraw()

# Build and display the dashboard on the shared figure/axes, using the
# module-level model list and bound categories.
interactive_axion_dashboard(models, categories=categories, fig=fig, ax=ax)



HBox(children=(Tab(children=(VBox(children=(VBox(children=(Checkbox(value=True, description='KSVZ'), Checkbox(…

Output()