In [None]:
%pip install -Uq ipywidgets matplotlib-label-lines

In [None]:
import numpy as np
import xarray as xr
import scipy.stats as stats
import matplotlib as mpl
import matplotlib.pyplot as plt
import xarray_einstats.stats as xr_stats
import ipywidgets as widgets
from labellines import labelLines as label_lines

from growthstandards import rv_coords, rv_sel, rv_interp, GrowthStandards, XrCompoundRV
from growthstandards.bcs_ext.scipy_ext import BCCG, BCPE

In [None]:
def coord_da(vs, name, attrs=None):
    """Build a 1-D DataArray over dimension ``name`` that is its own coordinate.

    Parameters
    ----------
    vs
        Values along the new dimension.
    name
        Name used for both the dimension and the coordinate.
    attrs
        Optional mapping of attributes to attach. ``None`` (the default)
        attaches no attributes.

    Returns
    -------
    xr.DataArray
        The values, with a coordinate called ``name`` computed from the array itself.
    """
    # ``None`` rather than a ``{}`` default: a dict default is a single object
    # shared by every call of the function.
    da = xr.DataArray(vs, dims=name).assign_attrs({} if attrs is None else attrs)
    # The callable is evaluated on the array, so the coordinate holds the same
    # values as the data.
    return da.assign_coords({name: lambda arr: arr})

In [None]:
from scipy.special import ndtr, log_ndtr, ndtri, ndtri_exp

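# TODO: use a masked/lazy `where` -- both branches are currently evaluated for every element, including the ones that are not selected.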

def calc_z_score(rv: xr_stats.XrRV, v: float | xr.DataArray, log: bool=False, apply_kwargs=None) -> xr.DataArray:
    """Convert value(s) ``v`` into z-scores under the distribution ``rv``.

    Values at or above ``rv.median()`` are mapped through the survival
    function (``-ndtri(sf(v))``); values below it through the CDF
    (``ndtri(cdf(v))``). Presumably this keeps precision in both tails.

    Parameters
    ----------
    rv
        Random variable exposing ``median``, ``cdf``/``sf`` (and
        ``logcdf``/``logsf`` when ``log`` is true) as well as ``coords``.
    v
        Value(s) to score.
    log
        If true, work with log-probabilities (``logsf``/``logcdf`` together
        with ``ndtri_exp``) instead of plain probabilities.
    apply_kwargs
        Forwarded unchanged to the ``rv`` probability methods.

    Returns
    -------
    xr.DataArray
        z-scores carrying ``rv``'s attrs (if it has any) and coordinates.
    """
    rv_coords_ = rv.coords
    rv_attrs = getattr(rv, "attrs", {})
    at_or_above_median = v >= rv.median()
    if log:
        upper = -ndtri_exp(rv.logsf(v, apply_kwargs=apply_kwargs))
        lower = ndtri_exp(rv.logcdf(v, apply_kwargs=apply_kwargs))
    else:
        upper = -ndtri(rv.sf(v, apply_kwargs=apply_kwargs))
        lower = ndtri(rv.cdf(v, apply_kwargs=apply_kwargs))
    # Both candidates are computed everywhere; `where` then picks per element.
    z_scores = xr.where(at_or_above_median, upper, lower)
    return z_scores.assign_attrs(rv_attrs).assign_coords(rv_coords_.variables)


def invert_z_score(rv: xr_stats.XrRV, z: float | xr.DataArray, log: bool=False, apply_kwargs=None) -> xr.DataArray:
    """Map z-score(s) ``z`` back to values of the distribution ``rv``.

    Non-negative z-scores go through the inverse survival function
    (``isf(ndtr(-z))``); negative ones through the quantile function
    (``ppf(ndtr(z))``).

    Parameters
    ----------
    rv
        Random variable exposing ``isf``/``ppf`` (and ``isf_exp``/``ppf_exp``
        when ``log`` is true) as well as ``coords``.
    z
        z-score(s) to invert.
    log
        If true, pass log-probabilities (``log_ndtr``) to ``isf_exp`` /
        ``ppf_exp``, presumably the log-probability counterparts of
        ``isf``/``ppf``.
    apply_kwargs
        Forwarded unchanged to the ``rv`` quantile methods.

    Returns
    -------
    xr.DataArray
        Values carrying ``rv``'s attrs (if it has any) and coordinates.
    """
    rv_coords_ = rv.coords
    rv_attrs = getattr(rv, "attrs", {})
    non_negative = z >= 0
    if log:
        upper = rv.isf_exp(log_ndtr(-z), apply_kwargs=apply_kwargs)
        lower = rv.ppf_exp(log_ndtr(z), apply_kwargs=apply_kwargs)
    else:
        upper = rv.isf(ndtr(-z), apply_kwargs=apply_kwargs)
        lower = rv.ppf(ndtr(z), apply_kwargs=apply_kwargs)
    # Both candidates are computed everywhere; `where` then picks per element.
    values = xr.where(non_negative, upper, lower)
    return values.assign_attrs(rv_attrs).assign_coords(rv_coords_.variables)

In [None]:
# Look up the individual growth-standard random variables by key.
len_rv = GrowthStandards["length"]
hei_rv = GrowthStandards["height"]
gfl_rv = GrowthStandards["gfl"]
gfh_rv = GrowthStandards["gfh"]

# Pair the "gfl"/"gfh" standards with the length/height standards as compound
# random variables and give each a display name. In this notebook they are only
# referenced by the commented-out alternative definition of `rvs`.
growth_len_rv = XrCompoundRV(gfl_rv, len_rv, "length")
growth_len_rv.attrs["long_name"] = 'Growth Metric (Recumbent Length)'
growth_hei_rv = XrCompoundRV(gfh_rv, hei_rv, "height")
growth_hei_rv.attrs["long_name"] = 'Growth Metric (Standing Height)'

In [None]:
# Show the keys under which growth standards can be looked up.
list(GrowthStandards.keys())

In [None]:
# Keys of the standards left out of the explorer.
# Use `skipped_rvs = ()` instead to keep every standard.
skipped_rvs = ("bmi_height", "bmi_length")

# Alternative that also includes the compound growth metrics:
# rvs = [*(v for k,v in GrowthStandards.items() if k not in skipped_rvs), growth_len_rv, growth_hei_rv]
rvs = [standard for key, standard in GrowthStandards.items() if key not in skipped_rvs]

# Summarise each standard: its display name, the available "sex" values,
# and the dtype of every other coordinate.
for rv in rvs:
    print(
        rv.attrs.get("long_name"),
        {c: (rv.coords[c].values if c == "sex" else rv.coords[c].dtype) for c in rv.coords.coords},
    )

In [None]:
# z-scores (-3 through 3) at which reference curves are computed.
z_da = coord_da(list(range(-3, 4)), "z")

# One array of curve values per growth standard, in the same order as `rvs`.
inverted_z_scores = [invert_z_score(rv, z_da, apply_kwargs=dict(keep_attrs=True)) for rv in rvs]
# inverted_z_scores

In [None]:
# Names of computations that presumably take a measured value as input
# (not referenced elsewhere in the visible notebook).
value_methods = {
    "cdf",
    "logcdf",
    "sf",
    "logsf",
    "pdf",
    "logpdf",
    "z-score",
}

# Names of computations that presumably take a probability as input.
prob_methods = {"ppf", "isf"}

In [None]:
# ipywidgets helper for the inline matplotlib backend; presumably called here so
# that pending figures are shown before the widgets below are built -- TODO confirm
# it is still needed.
widgets.interaction.show_inline_matplotlib_plots()

In [None]:
from IPython.display import Markdown

# Colour per |z|: curves for +z and -z share one matplotlib cycle colour.
# A sparser alternative would be {0: "C0", 2: "C1", 3: "C2"}.
z_c_map = {
    0: "C0",
    1: "C1",
    2: "C2",
    3: "C3",
}


def plot_inv_zscore_lines(ax, inv_zscore_da):
    """Draw one labelled reference curve per z-score onto ``ax``.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    inv_zscore_da
        DataArray with a ``z`` coordinate to iterate over and a scalar ``sex``
        coordinate; its ``long_name`` attribute (or else its name) is used in
        the title.
    """
    curves = []
    for z in inv_zscore_da.coords["z"]:
        z_int = int(z)
        # Colour is looked up by |z|, so +z and -z curves match.
        colour = z_c_map[abs(z_int)]
        curves += inv_zscore_da.sel(z=z).plot.line(ax=ax, c=colour, label=f"{z_int}")
    # Write each curve's label on the line itself.
    label_lines(curves, fontsize=16)
    ax.autoscale(enable=True, axis="x", tight=True)
    title = inv_zscore_da.attrs.get("long_name", inv_zscore_da.name)
    sex = inv_zscore_da["sex"].item()
    ax.set_title(f"{title} ({sex})")


def rv_widget(rv: "Union[xr_stats.XrDiscreteRV, xr_stats.XrContinuousRV, XrCompoundRV]", inv_zscore_da, value_max=1e6):
    """Build an interactive z-score calculator for a single growth standard.

    The widget offers a sex dropdown, a slider for the standard's other
    coordinate, a numeric field for the measured value and a "Compute" button.
    It shows the resulting z-score and marks the measurement on a plot of the
    reference curves.

    Parameters
    ----------
    rv
        Growth-standard random variable. ``rv.coords.coords`` must hold exactly
        two coordinates: ``"sex"`` and one other. ``rv.attrs`` must contain
        ``"long_name"``.
    inv_zscore_da
        Reference curves for ``rv`` (see ``invert_z_score``); must be
        selectable by ``sex``.
    value_max
        Upper bound of the measurement field. ``BoundedFloatText`` defaults to
        ``max=100``, which silently clamps larger measurements, so a generous
        bound is passed explicitly.

    Returns
    -------
    widgets.VBox
        The assembled controls, result line and plot output.
    """
    coords = dict(rv.coords.coords.items())
    assert len(coords) == 2, f"expected 'sex' plus one other coordinate, got {list(coords)}"
    sv = coords.pop("sex")
    ((c, coord),) = coords.items()

    def _label(name, units):
        # Leave the parenthesised unit out when none is recorded.
        return widgets.Label(f"{name} ({units}):" if units else f"{name}:")

    # Start without a selection: setting the index at the end, after the
    # observers are registered, triggers the first plot.
    sex_widget = widgets.Dropdown(
        options=sv.values.tolist(),
        disabled=False,
        index=None,
    )
    sex_label = widgets.Label("Sex:")

    coord_min = coord.min().values.item()
    coord_max = coord.max().values.item()
    coord_label = _label(coord.attrs.get("long_name", c), coord.attrs.get("units"))

    slider_kwargs = dict(
        min=coord_min,
        max=coord_max,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
    )
    # Integer-valued coordinates, and "age" regardless of dtype, move in whole steps.
    if coord.dtype.kind in "iu" or c == "age":
        coord_widget = widgets.IntSlider(step=1, readout_format="d", **slider_kwargs)
    else:
        coord_widget = widgets.FloatSlider(step=0.1, readout_format='', **slider_kwargs)

    # Use the text before "for" in the long name as the measurement's label.
    value_name = rv.attrs["long_name"].split("for")[0].strip()
    value_label = _label(value_name, rv.attrs.get("units"))
    value_widget = widgets.BoundedFloatText(min=0, max=value_max, step=0.1, value=None)
    value_box = widgets.HBox([value_label, value_widget])

    compute_widget = widgets.Button(description="Compute", icon='check')

    math_widget = widgets.HTMLMath(value="")

    out = widgets.Output()

    # Create the figure with interactive mode off so it is only shown through `out`.
    with plt.ioff():
        fig = plt.figure()
    ax = fig.gca()

    def _plot_line(*_args, point=None):
        # Redraw the reference curves for the selected sex, plus the measurement if given.
        ax.clear()
        plot_inv_zscore_lines(ax, inv_zscore_da.sel(sex=sex_widget.value))
        if point is not None:
            x, y = point
            ax.scatter(x, y, color="black", zorder=3)
        display(fig)

    # Route the figure into `out`, replacing whatever was shown before.
    plot_line = out.capture(clear_output=True)(_plot_line)

    def _compute(*_args):
        _rv = rv_sel(rv, {"sex": sex_widget.value})
        _rv = rv_interp(_rv, {c: coord_widget.value})
        # Reduce the one-element result to a plain float before formatting.
        z = calc_z_score(_rv, value_widget.value).item()
        if np.isfinite(z):
            z_text = f"{z:g}"
        elif z < 0:
            z_text = "-∞"
        elif z > 0:
            z_text = "∞"
        else:
            z_text = "NaN"
        math_widget.value = f"$$ Z = {z_text} $$"
        plot_line(point=(coord_widget.value, value_widget.value))

    def _reset(*_args):
        # A changed input invalidates the shown result: clear it and drop the marker.
        if math_widget.value:
            math_widget.value = ""
            plot_line()

    sex_widget.observe(plot_line, names='index')
    coord_widget.observe(_reset, names='value')
    value_widget.observe(_reset, names='value')
    compute_widget.on_click(_compute)

    # Select the first sex now that the observer is attached; this draws the initial plot.
    sex_widget.index = 0

    return widgets.VBox([
        widgets.HBox([sex_label, sex_widget]), widgets.HBox([coord_label, coord_widget]),
        value_box, compute_widget, math_widget, out])

In [None]:
# One calculator per growth standard; only the selected one is shown.
stack = widgets.Stack(
    [rv_widget(standard, curves) for standard, curves in zip(rvs, inverted_z_scores)],
    selected_index=0,
)

# Dropdown listing the standards by display name, linked in the browser to the stack.
dropdown = widgets.Dropdown(options=[standard.attrs["long_name"] for standard in rvs])
widgets.jslink((dropdown, 'index'), (stack, 'selected_index'))

# Final expression of the cell: the selector row above the selected calculator.
widgets.VBox([
    widgets.HBox([widgets.Label("Growthstandard:"), dropdown]),
    stack,
])