In [None]:
import numpy as np
from jax import numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from IPython.display import display

import astropy.units as u
from astropy.table import Table
from astropy.visualization import simple_norm
from ipywidgets import (
    interactive, FloatSlider, HBox, VBox, Button, 
    Layout, Box, Checkbox, Output, IntSlider
)

from cs_sample.jax import target_cost
from cs_sample.core import (
    download_sheet, priority_from_cs_distance
)

In [None]:
sheet = download_sheet()

# Spreadsheet columns to pull as numeric arrays; the tuple unpacking below
# must list the variables in exactly this order.
columns = [
    'Teff', 'Kmag', 'Rp/Rs', 'a/Rs', 'Eclipse Dur',
    'Instellation', 'Escape Velocity', 'Has < 20% mass constraint?',
    'Rp', 'Mp', 'XUV Instellation', 'Teq', '1 eclipse depth precision'
]
(
    teff, kmag, rp_rs, aRs, eclipse_dur, instellation,
    v_esc, mass_constraint, Rp, Mp, xuv, Teq, one_eclipse_precision_hdl
) = jnp.array(
    # dtype=float: a boolean/object column (e.g. the mass-constraint flag)
    # would otherwise turn the whole block into an object array, which
    # jnp.array cannot convert.
    sheet[columns].to_numpy(dtype=float).T
)

# Display names: drop any parenthetical suffix from the 'Planet name' entries.
names = np.array([t.split('(')[0].strip() for t in sheet['Planet name']])

priority, x, y = priority_from_cs_distance(v_esc, instellation)

# Bulk density in units of Earth's mean density (Mp and Rp are treated as
# Earth masses / Earth radii).
rho_earth = u.def_unit('rho_earth', 1 * u.M_earth / (4/3 * np.pi * (1 * u.R_earth)**3))
density = (Mp * u.M_earth / (4/3 * np.pi * (Rp * u.R_earth)**3)).to_value(rho_earth)

# Targets already in JWST GO programs. Match against the cleaned `names` so a
# parenthetical alias in the sheet's 'Planet name' does not hide a match.
in_go_programs = np.isin(
    names,
    ['LP 791-18 d', 'TRAPPIST-1 b', 'TRAPPIST-1 c']
)

# Targets whose free-text comments mention the Hot Rocks program; missing
# (non-string) comments count as "not in Hot Rocks".
in_hot_rocks = np.array([
    "Hot Rocks" in comment if isinstance(comment, str) else False
    for comment in sheet['General comments'].tolist()
])

In [None]:
# Colour scale for target priority, shared by every redraw of the scatter plot.
norm = simple_norm(priority, stretch='linear', min_cut=-0.1, max_cut=priority.max())

# Preset control values, one entry per scenario button. Each inner key is
# matched against a control's `description` when the button is clicked
# (presumably the corresponding `plot` argument name for the sliders).
scenarios = {
    'mercury_vs_venus': {
        'eps_max': 1.0,   # perfect redist
        'AB_min': 0.119,  # Mercury
        'AB_max': 0.75,   # Venus
        'n_sigma': 3,
    },

    'mars_vs_mercury': {
        'eps_max': 0.04,  # Mars
        'AB_min': 0.119,  # Mercury
        'AB_max': 0.16,   # Mars
        'n_sigma': 3,
    },

    'mercury_vs_earth': {
        'eps_max': 1.0,   # perfect redist
        'AB_min': 0.119,  # Mercury
        'AB_max': 0.29,   # Earth
        'n_sigma': 3,
    },

    # Only this preset also pins the host-star temperature window.
    'TUC_v1': {
        'eps_max': 1,
        'AB_min': 0.1,
        'AB_max': 0.3,
        'n_sigma': 4,
        'teff_min': 2500,
        'teff_max': 4000,
    },
}

In [None]:

table_output = Output()


def _exclusion_mask(teff_min, teff_max, teq_max, include_go,
                    include_hot_rocks, include_imprecise_mass):
    """Boolean array over all targets: True where the current filters exclude it."""
    # Host star outside the open interval (teff_min, teff_max).
    excluded = ~np.array((teff_min < teff) & (teff < teff_max))
    # Planet too hot; a NaN Teq fails the comparison and is excluded as well.
    excluded |= ~(Teq < teq_max)
    if not include_go:
        excluded |= in_go_programs
    if not include_hot_rocks:
        excluded |= in_hot_rocks
    if not include_imprecise_mass:
        excluded |= ~mass_constraint.astype(bool)
    return excluded


def _select_within_budget(cost, excluded, max_hrs):
    """Row indices of the cheapest non-excluded targets whose running cost total stays below max_hrs."""
    keep = ~excluded
    order = jnp.argsort(cost[keep])
    # searchsorted (side='left') on the cumulative cost counts the prefix whose
    # running total is strictly below max_hrs (assumes non-negative costs so
    # the cumulative sum is sorted).
    n_selected = jnp.searchsorted(np.cumsum(cost[keep][order]), max_hrs)
    return np.arange(len(cost))[keep][order][:n_selected]


def _build_target_table(cost, sheet_mask, include_go, include_hot_rocks,
                        include_imprecise_mass, use_xuv):
    """Astropy Table summarising the selected targets, in selection order."""
    table_contents = {
        'target': names[sheet_mask],
        'cost [hr]': cost[sheet_mask],
        'priority': priority[sheet_mask],
        # density is in units of Earth's bulk density (rho_earth), hence \oplus.
        '$\\rho$ [$\\rho_\\oplus$]': density[sheet_mask],
        'Teq': Teq[sheet_mask],
        '$v_{\\rm esc}$ [km/s]': v_esc[sheet_mask],
    }
    if use_xuv:
        table_contents['$I_{\\rm XUV}$ [$I_\\odot$]'] = xuv[sheet_mask]
    else:
        table_contents['I [$I_\\odot$]'] = instellation[sheet_mask]

    # When a normally-excluded category is switched on, flag its members.
    flag_columns = [
        ('GO', include_go, in_go_programs),
        ('HotRocks', include_hot_rocks, in_hot_rocks),
        ('ImpMass', include_imprecise_mass, ~mass_constraint.astype(bool)),
    ]
    for header, included, members in flag_columns:
        if included:
            table_contents[header] = np.where(members[sheet_mask], '❌', '')

    target_table = Table(table_contents)

    # One decimal place for every numeric column (skip the name and flag columns).
    flag_headers = {header for header, _, _ in flag_columns}
    for col in target_table.colnames[1:]:
        if col not in flag_headers:
            target_table[col].format = '0.1f'
    return target_table


def _plot_histograms(ax_hist, cost, sheet_mask):
    """Top-row histograms: full sample in silver behind the selected targets in C0."""
    panels = [
        # (values, x-label, draw full-sample background, log y-axis, log x-axis)
        (teff, '$T_{\\rm eff}$ [K]', True, True, False),
        (cost, '$t_{\\rm obs}$ [hrs]', False, False, False),
        (priority, 'priority', True, True, False),
        (density, '$\\rho$ [$\\rho_\\oplus$]', True, True, True),
        (np.where(xuv < 1e4, np.log10(xuv), np.nan), 'log XUV', True, True, False),
    ]
    for axis, (values, label, background, log_y, log_x) in zip(ax_hist, panels):
        if background:
            # Reuse the full-sample bins so both histograms line up.
            bins = axis.hist(values, alpha=0.2, color='silver')[1]
        else:
            # No full-sample background; bins come from the selection alone.
            bins = None
        axis.hist(values[sheet_mask], color='C0', bins=bins)
        axis.set(xlabel=label)
        if log_y:
            axis.set_yscale('log')
        if log_x:
            axis.set_xscale('log')


def plot(
    eps_max, AB_min, AB_max, n_sigma,
    teff_min, teff_max, teq_max, max_hrs,
    noise_excess,
    include_go=False,
    include_hot_rocks=False,
    include_imprecise_mass=False,
    use_xuv=False
):
    """Select targets that fit the time budget and redraw the summary figure and table.

    Targets are filtered, costed with ``target_cost``, then taken cheapest-first
    until the cumulative cost would reach ``max_hrs``. Draws and shows a
    matplotlib figure (escape velocity vs. instellation plus a row of
    histograms) and replaces the contents of ``table_output`` with a table of
    the selected targets.

    Parameters
    ----------
    eps_max, AB_min, AB_max, n_sigma : float
        Forwarded to ``target_cost`` (redistribution efficiency, Bond-albedo
        bounds and the required detection significance, per the slider tooltips).
    teff_min, teff_max : float
        Exclusive bounds on host-star effective temperature.
    teq_max : float
        Exclusive upper bound on planetary equilibrium temperature.
    max_hrs : float
        Total observing-time budget, in the same units as the cost (hours per the labels).
    noise_excess : float
        Forwarded to ``target_cost`` as ``photon_noise_excess``.
    include_go, include_hot_rocks, include_imprecise_mass : bool
        Keep targets that are in JWST GO programs / mentioned as Hot Rocks in the
        sheet's comments / lack a <20% mass constraint, instead of excluding
        them. Kept members are flagged with ❌ in the table.
    use_xuv : bool
        Plot and tabulate the XUV instellation column instead of 'Instellation'.
    """
    excluded = _exclusion_mask(
        teff_min, teff_max, teq_max,
        include_go, include_hot_rocks, include_imprecise_mass
    )

    # The second return value (a sort order) is not needed here.
    cost, _ = target_cost(
        teff=teff, aRs=aRs, AB_min=AB_min, AB_max=AB_max,
        eps_max=eps_max, rp_rs=rp_rs, K_mag=kmag,
        n_sigma=n_sigma, eclipse_dur=eclipse_dur,
        one_eclipse_precision_hdl=one_eclipse_precision_hdl,
        photon_noise_excess=noise_excess
    )
    sheet_mask = _select_within_budget(cost, excluded, max_hrs)

    fig = plt.figure(figsize=(7, 7), dpi=150)
    gs = GridSpec(5, 5, figure=fig)
    ax_hist = [fig.add_subplot(gs[0, i]) for i in range(gs.ncols)]
    ax = fig.add_subplot(gs[1:, :])

    # Reference curve: the (x, y) returned by priority_from_cs_distance
    # (presumably the cosmic shoreline), or for XUV the v_esc**4 scaling
    # from Zahnle & Catling 2017.
    if use_xuv:
        boundary = 1e-6 / (0.18 ** 4) * x ** 4
    else:
        boundary = y
    ax.loglog(x, boundary, lw=3, color='silver', zorder=-100, alpha=0.5)

    plot_instell = xuv if use_xuv else instellation
    # Whole sample in grey; selection on top, coloured by priority.
    ax.scatter(
        v_esc, plot_instell,
        edgecolor='none',
        color='silver',
        alpha=0.3
    )
    selected_points = ax.scatter(
        v_esc[sheet_mask],
        plot_instell[sheet_mask],
        c=priority[sheet_mask],
        edgecolor='none',
        norm=norm
    )
    # Label each selected point with its position in the selection (= table row order).
    for i, (vx, vy) in enumerate(zip(v_esc[sheet_mask], plot_instell[sheet_mask])):
        ax.annotate(f' {i}', (vx, vy), ha='left', va='bottom', fontsize=8)

    plt.colorbar(selected_points, ax=ax, label='priority')

    ax.set(
        xlabel='$v_{\\rm esc}$ [km s$^{-1}$]',
        ylabel=('XUV ' if use_xuv else '') + 'Instellation [I$_{\\odot}$]',
        xscale='log',
        yscale='log',
    )

    target_table = _build_target_table(
        cost, sheet_mask,
        include_go, include_hot_rocks, include_imprecise_mass, use_xuv
    )

    notes = (
        f'N$_{{\\rm targets}}$ = {len(sheet_mask)}\n'
        f'Total obs time = {cost[sheet_mask].sum():.0f} hrs\n'
    )
    ax.annotate(
        notes, (0.05, 0.95),
        xycoords='axes fraction',
        va='top', ha='left'
    )

    _plot_histograms(ax_hist, cost, sheet_mask)

    fig.tight_layout()
    plt.show()

    with table_output:
        table_output.clear_output()
        display(target_table.show_in_notebook(display_length=-1, show_row_index='marker'))
        
# Interactive controls for `plot`. The layout cell below slices
# widget.children assuming this order: the 9 sliders, then the 4 checkboxes,
# then the output area — keep it in sync with the `plot` signature.
widget = interactive(
    plot,
    eps_max=FloatSlider(
        value=1, min=0, max=1, step=0.05,
        tooltip='Redist. efficiency'
    ),
    AB_min=FloatSlider(
        value=0.1, min=0, max=1, step=0.05,
        tooltip='Bond albedo no redist.'
    ),
    AB_max=FloatSlider(
        value=0.3, min=0, max=1, step=0.05,
        tooltip='Bond albedo with redist.'
    ),
    n_sigma=FloatSlider(
        value=4, min=3, max=6, step=0.1,
        tooltip='Require detection above N sigma'
    ),
    teff_min=IntSlider(
        value=3200, min=2500, max=4000, step=10,
        tooltip='Minimum stellar T_eff'
    ),
    teff_max=IntSlider(
        value=3800, min=2500, max=4000, step=10,
        tooltip='Maximum stellar T_eff'
    ),
    teq_max=IntSlider(
        value=600, min=200, max=1000, step=10,
        tooltip='Max planetary equilibrium temperature'
    ),
    max_hrs=IntSlider(
        value=500, min=500, max=700, step=10,
        tooltip='Max total obs. hours'
    ),
    noise_excess=FloatSlider(
        value=1, min=0, max=3, step=0.1,
        tooltip='From photon noise (0), to scaled T-1c precision (1), and beyond (>1)'
    ),
    # Checkbox has no `icon` trait (that belongs to ToggleButton); passing one
    # only triggers a traitlets "unrecognized arguments" warning.
    include_go=Checkbox(
        value=False, description='Include GO',
        tooltip='Include JWST GO targets'
    ),
    include_hot_rocks=Checkbox(
        value=False, description='Include Hot Rocks',
        tooltip='Include Hot Rocks targets'
    ),
    # Ticking this keeps targets that LACK a <20% mass constraint.
    include_imprecise_mass=Checkbox(
        value=False, description='Include imprecise masses',
        tooltip='Include targets without a <20% mass constraint'
    ),
    use_xuv=Checkbox(
        value=False, description='Show XUV instel.',
    ),
)


def select_scenario(button):
    """Button callback: push a preset scenario's values into the interactive controls.

    The clicked button's ``description`` names an entry in ``scenarios``. Every
    child of ``widget`` whose ``description`` equals one of that preset's keys
    gets its ``value`` replaced; controls the preset does not mention are left
    untouched.
    """
    for control in widget.children:
        preset = scenarios[button.description]
        key = getattr(control, 'description', None)
        if key in preset:
            control.value = preset[key]

# One button per named scenario; clicking it loads that preset into the controls.
buttons = []
for scenario_name in scenarios:
    scenario_button = Button(description=scenario_name)
    scenario_button.on_click(select_scenario)
    buttons.append(scenario_button)

# Right column: scenario buttons stacked above the checkboxes, centred.
button_column = VBox(buttons)
checkbox_column = VBox(widget.children[9:-1])
vbox = VBox((button_column, checkbox_column))
vbox.layout.align_items = 'center'

# Left column: the nine sliders.
slider_column = VBox(widget.children[:9])
controls = HBox([slider_column, vbox])

# The interactive's own output area (the figure) sits beside controls + table.
output = widget.children[-1]
box = Box([output, VBox([controls, table_output])])
box

# Notes

## Feature requests

* [X] toggle for requirement $200 < T_{\rm eq} < 600$ K
* [X] slider for max program limit
* [ ] checkbox dropdown for required targets
* [X] "Cost optimism": eclipse precision is currently scaled from TRAPPIST-1 c, but we could do better for some targets. Allow the user to vary between Hannah's and Néstor's eclipse-precision estimates.
* [ ] "Is this a rock?": make the mass-precision requirement a slider?
* [X] toggle instellation vs. XUV instellation y-axis

# User profiles
* Néstor, Hannah, Brett
* Concern: "Is this a rock?"

In [None]:
# Solar System Bond albedos, from
# http://hyperphysics.phy-astr.gsu.edu/hbase/phyopt/albedo.html
# Each entry is a [name, albedo] pair; both fields are kept as strings.
ss_bond_albedos = [
    ['Mercury', '0.119'],
    ['Venus', '0.75'],
    ['Earth', '0.29'],
    ['Moon', '0.123'],
    ['Mars', '0.16'],
    ['Pluto', '0.4'],
]

In [None]:
# Day/night surface temperatures of Solar System bodies, used to estimate a
# heat-redistribution efficiency eps = T_night**4 / T_day**4 for comparison
# with the eps_max / A_B values used in the scenarios.
# https://ui.adsabs.harvard.edu/abs/2023prmw.conf..174V/abstract

T_day_mercury = 623 * u.K
T_night_mercury = 103 * u.K

# Mars: https://www.aeronomie.be/en/encyclopedia/mars-climate-important-temperature-difference-between-day-and-night
T_day_mars = 300 * u.K
T_night_mars = 140 * u.K

T_day_venus = 737 * u.K
T_night_venus = 737 * u.K

# Order: Mercury, Venus, Mars (matches the unpacking and the albedo list below).
epsilons = [
    T_night ** 4 / T_day ** 4
    for T_day, T_night in [
        (T_day_mercury, T_night_mercury),
        (T_day_venus, T_night_venus),
        (T_day_mars, T_night_mars),
    ]
]

eps_mercury, eps_venus, eps_mars = epsilons
ABs = AB_mercury, AB_venus, AB_mars = [0.119, 0.75, 0.16]

# Each eps is a dimensionless Quantity (K^4 / K^4); plot plain floats.
eps_values = [float(eps) for eps in epsilons]

fig, ax = plt.subplots()
ax.plot(ABs, eps_values, 'o')
for body, AB, eps in zip(['Mercury', 'Venus', 'Mars'], ABs, eps_values):
    ax.annotate(f' {body}', (AB, eps), ha='left', va='bottom', fontsize=8)
ax.set(
    xlim=[-0.1, 1.1],
    ylim=[-0.1, 1.1],
    xlabel='$A_{\\rm B}$',
    ylabel='$\\varepsilon$'
)
plt.show()