In [None]:
import numpy as np
import xarray as xr
import scipy.stats as stats
import matplotlib.pyplot as plt

In [None]:
from xarray_stats import XrContinuousRV
from bcs_ext.scipy_ext import BCCG, BCPE

In [399]:
def coord_da(vs, name):
    """Wrap ``vs`` in a ``DataArray`` that serves as its own coordinate.

    The result has the single dimension ``name`` and a coordinate, also called
    ``name``, holding the array's own values.
    """
    values = xr.DataArray(vs, dims=name)
    return values.assign_coords({name: values})

In [None]:
# Group names read from the growth-standards zarr store, one dataset per group.
growthstandard_keys = ("ac", "bmi", "hc", "len", "ss", "ts", "wei", "wfh", "wfl")
# Each group is opened without time decoding and loaded eagerly into memory.
growthstandard_dss = dict(
    (
        key,
        xr.open_zarr(
            store="growthstandards.zarr", group=key, decode_times=False
        ).load(),
    )
    for key in growthstandard_keys
)
growthstandard_dss

In [None]:
def _gds_to_rv(gds):
    """Build a BCCG random variable from a growth-standard dataset.

    The variables ``m``, ``s`` and ``l`` are renamed to ``mu``, ``sigma`` and
    ``nu`` before the dataset is handed to ``XrContinuousRV.from_ds``.
    """
    lms_to_bccg = {"m": "mu", "s": "sigma", "l": "nu"}
    renamed = gds.rename_vars(lms_to_bccg)
    return XrContinuousRV.from_ds(BCCG, renamed)


def _gdss_to_rvs(gdss_dict):
    """Yield ``(name, random variable)`` pairs for a dict of growth-standard datasets.

    Datasets without a ``lorh`` coordinate are converted as-is under their own
    key.  For datasets that carry ``lorh``:

    * ``"wfl"`` / ``"wfh"``: the ``lorh`` coordinate is dropped.
    * ``"len"``: three entries are produced -- ``"len"`` (rows where
      ``lorh == "L"``), ``"hei"`` (rows where ``lorh == "H"``), both with
      ``lorh`` dropped, and ``"len_hei"`` (the unfiltered dataset).
    * any other key: converted as-is.

    Parameters
    ----------
    gdss_dict : mapping of str to xarray.Dataset
        Growth-standard datasets keyed by indicator name.

    Yields
    ------
    tuple
        ``(name, rv)`` where ``rv`` is the result of ``_gds_to_rv``.
    """
    for k, gds in gdss_dict.items():
        if "lorh" not in gds.coords:
            yield k, _gds_to_rv(gds)
        elif k in ("wfl", "wfh"):
            # ``Dataset.drop`` is deprecated; ``drop_vars`` is the explicit
            # replacement for removing a variable/coordinate by name.
            yield k, _gds_to_rv(gds.drop_vars("lorh"))
        elif k == "len":
            for name, flag in (("len", "L"), ("hei", "H")):
                subset = gds.where(gds.lorh == flag, drop=True)
                yield name, _gds_to_rv(subset.drop_vars("lorh"))
            yield "len_hei", _gds_to_rv(gds)
        else:
            yield k, _gds_to_rv(gds)


# Materialise the generator into a dict keyed by random-variable name, then display it.
growthstandard_rvs = dict(_gdss_to_rvs(growthstandard_dss))
growthstandard_rvs

In [None]:
# One row of panels per indicator, one column per sex.
# _vs = "wei", "len", "hei", "bmi", "wfl", "wfh"
_vs = "wei", "len_hei", "bmi", "wfl", "wfh"

fig, axs = plt.subplots(len(_vs), 2, layout="constrained", figsize=(12, 5 * len(_vs)))

for raxs, v in zip(axs, _vs):
    grv = growthstandard_rvs[v]
    # Evaluation grid spanning the overall 1%..99% quantile range of this indicator.
    _y = coord_da(np.linspace(grv.ppf(0.01).min(), grv.ppf(0.99).max(), num=1000), v)
    p2d = grv.pdf(_y)
    for ax, s in zip(raxs, ("Female", "Male")):
        # ``DataArray.drop`` is deprecated; ``drop_vars`` removes the coordinate by name.
        p2d.sel(sex=s).drop_vars("sex_enum").plot.imshow(
            y=v, add_colorbar=False, ax=ax
        )

In [None]:
def _rtol(a, b):
    return abs(a - b) / abs(a)

In [None]:
# Sanity checks on one random variable: compare its stored parameters with
# values recovered from its own quantiles.
# grv = growthstandard_rvs["len_hei"]
grv = growthstandard_rvs["len"]
print(grv.as_ds()["mu"].values)
print(grv.median().values)
# True when the ``mu`` parameter matches the distribution median (np.allclose tolerance).
print(np.allclose(grv.as_ds()["mu"], grv.median()))
print(grv.as_ds()["sigma"].values)
# Estimate sigma from the quartiles: half the interquartile range relative to
# the median, passed through arcsinh and divided by the standard-normal 75% quantile.
_grv_approx_sigma = (
    np.arcsinh(
        ((grv.ppf(0.75) - grv.ppf(0.25)) / grv.median()) / 2
    ) / stats.norm.ppf(0.75)
)
print(_grv_approx_sigma.values)
# Largest relative deviation between the stored sigma and the quartile-based estimate.
req_rtol = _rtol(grv.as_ds()["sigma"], _grv_approx_sigma).max()
print(f"Required relative tolerance: {req_rtol:e}")

In [None]:
# Plot the sigma parameter of the "wfl" random variable against the ``length`` coordinate.
growthstandard_rvs["wfl"].as_ds()["sigma"].plot.line(x="length")

In [None]:
def xr_bcs_mul(rv, c):
    """Return a new ``XrContinuousRV`` whose first parameter is multiplied by ``c``.

    The distribution family (``rv._distr``) and all remaining parameters of
    ``rv.as_ds()`` are passed through unchanged.
    NOTE(review): presumably the first parameter is the scale-like ``mu``;
    this depends on the variable order of ``rv.as_ds()`` -- confirm.
    """
    family = rv._distr
    params = list(rv.as_ds().values())
    scaled_first = params[0] * c
    return XrContinuousRV(family, scaled_first, *params[1:])


In [None]:
# Quantile levels used below: the quartiles plus two extreme tails at _q0 and 1 - _q0.
_q0 = 1e-4
quantiles = coord_da([_q0, 0.25, 0.5, 0.75, 1 - _q0], "quantile")
quantiles

### TODO: derive P(G_A) from P(W | L), P(W_A), P(L_A), where G = W / L, and A is indexed by age
$$
\begin{align}
A &= \text{age} \\
G &= W / L \\
p_G(g) &= \int_L |l| p_{W,L}(g l, l) \mathrm{d}l \\
&= \int_L |l| p_{W | L}(g l | l) p_L(l) \mathrm{d}l \\
&= \mathrm{E}_L[|L| p_{W | L}(g L | L)] \\
\mathrm{P}(G=g) &= \int_L |l| \mathrm{P}(W = g l | L = l) \mathrm{P}(L = l) \mathrm{d}l
\end{align}
$$

In [None]:
# Length / height grids taken from the weight-for-length / weight-for-height parameter sets.
_len = growthstandard_rvs["wfl"].as_ds().length
_hei = growthstandard_rvs["wfh"].as_ds().height

len_rv = growthstandard_rvs["len"]
hei_rv = growthstandard_rvs["hei"]
# Scale the first parameter of the weight-for-length/height variables by 1 / length
# (resp. 1 / height); intended as the conditional distribution of G = W / L given L,
# following the derivation in the markdown cell above.
gfl_rv = xr_bcs_mul(growthstandard_rvs["wfl"], 1 / _len)
gfh_rv = xr_bcs_mul(growthstandard_rvs["wfh"], 1 / _hei)

# Evaluation grid for the growth ratio (100 evenly spaced points).
growth_da = coord_da(np.linspace(0.01, 0.25, 100), "growth")
# growth_da = coord_da(np.linspace(0.01, 0.25, 1_000), "growth")
# Densities of length / height evaluated on their grids.
_p_l = len_rv.pdf(_len)
_p_h = hei_rv.pdf(_hei)
# Conditional densities of the growth ratio on the growth grid.
_p_lg_l = gfl_rv.pdf(growth_da)
_p_hg_h = gfh_rv.pdf(growth_da)
# Marginalise: integrate conditional density times length/height density over length/height.
p_g_l = (_p_lg_l * _p_l).integrate("length")
p_g_h = (_p_hg_h * _p_h).integrate("height")
# Merge the length-based and height-based results by their coordinates.
p_g = xr.combine_by_coords([p_g_l, p_g_h])
p_g

In [None]:
# Integrate the density over the growth grid as a check; a proper density that is
# fully covered by the grid would integrate to about 1.
p_g.integrate("growth")

In [None]:
# Same marginalisation as for the density, applied to the conditional CDFs:
# integrate conditional CDF times length/height density over length/height.
_cdf_lg_l = gfl_rv.cdf(growth_da)
_cdf_hg_h = gfh_rv.cdf(growth_da)
cdf_g_l = (_cdf_lg_l * _p_l).integrate("length")
cdf_g_h = (_cdf_hg_h * _p_h).integrate("height")
# Merge the length-based and height-based results by their coordinates.
cdf_g = xr.combine_by_coords([cdf_g_l, cdf_g_h])
cdf_g

In [None]:
def min_max(x, dim=None):
    """Return the ``(minimum, maximum)`` pair of ``x`` reduced over ``dim``."""
    lowest = x.min(dim=dim)
    highest = x.max(dim=dim)
    return lowest, highest

In [None]:
import scipy.optimize as optimize

In [None]:
def scalar_find_ppf_as_root(q, lb, ub, x, p_x, *shape_params):
    """Find the ``q``-quantile of a BCCG mixture by bracketed root finding.

    The mixture CDF at ``g`` is the trapezoidal integral over ``x`` of
    ``p_x * BCCG(*shape_params).cdf(g)``.  The returned value is the ``g``
    within ``[lb, ub]`` at which that CDF equals ``q``.

    Parameters
    ----------
    q : float
        Target probability level.
    lb, ub : float
        Lower and upper end of the search bracket.
    x : array-like
        Integration grid.
    p_x : array-like
        Weights on ``x`` (the caller passes a density evaluated on ``x``).
    *shape_params : array-like
        Parameters forwarded to ``BCCG``, aligned with ``x``.

    Returns
    -------
    float
        ``lb`` if the mixture CDF at ``lb`` is already ``>= q``; ``ub`` if the
        mixture CDF at ``ub`` is still ``<= q``; otherwise the root found
        inside the bracket.
    """
    # ``np.trapz`` was renamed ``np.trapezoid`` in NumPy 2.0 and the old name
    # is deprecated; prefer the new name and fall back on older NumPy.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    cond_rv = BCCG(*shape_params)

    def fun(g):
        return trapezoid(p_x * cond_rv.cdf(g), x) - q

    # Clamp to the bracket ends when there is no sign change inside it.
    if fun(lb) >= 0:
        return lb
    if fun(ub) <= 0:
        return ub
    # With a bracket, root_scalar uses a bracketing solver, which does not use
    # a derivative, so none is supplied; the method is named explicitly.
    res = optimize.root_scalar(fun, bracket=(lb, ub), method="brentq")
    return res.root


In [None]:
def find_ppf_as_root(q, x, x_rv, cond_rv):
    """Vectorised quantile search over the dimensions of ``x``.

    For every combination of the non-core dimensions, calls
    ``scalar_find_ppf_as_root`` with the quantile level ``q``, a bracket taken
    from the minimum/maximum of ``cond_rv.ppf(q)`` over ``x.dims``, the grid
    ``x``, the weights ``x_rv.pdf(x)`` and the parameters of ``cond_rv``.
    """
    core = x.dims
    cond_quantiles = cond_rv.ppf(q)
    lower = cond_quantiles.min(core)
    upper = cond_quantiles.max(core)
    # g0 = cond_g_q.mean(x.dims)
    params = list(cond_rv.as_ds().values())
    weights = x_rv.pdf(x)
    # q and the bracket ends are scalars per call; the grid, the weights and
    # every distribution parameter carry the dimensions of ``x``.
    scalar_dims = [[], [], []]
    grid_dims = [core for _ in range(len(params) + 2)]
    return xr.apply_ufunc(
        scalar_find_ppf_as_root,
        q,
        lower,
        upper,
        x,
        weights,
        *params,
        input_core_dims=scalar_dims + grid_dims,
        vectorize=True,
    )

In [None]:
# Quantiles of the growth ratio via root finding, separately for the length-based
# and height-based variables, then merged by coordinates.
l_distr_qs = find_ppf_as_root(quantiles, _len, len_rv, gfl_rv)
h_distr_qs = find_ppf_as_root(quantiles, _hei, hei_rv, gfh_rv)
distr_qs = xr.combine_by_coords([l_distr_qs, h_distr_qs])
distr_qs

In [None]:
fig, axs = plt.subplots(3, 2, layout="constrained", figsize=(14, 15))

# Top row: the marginal density, one panel per sex index.
for _sex_idx, _ax in enumerate(axs[0]):
    p_g.isel(sex=_sex_idx).plot.imshow(x="age", ax=_ax)

# Middle row: the marginal CDF.
growth_cdf = cdf_g
for _sex_idx, _ax in enumerate(axs[1]):
    growth_cdf.isel(sex=_sex_idx).plot.imshow(x="age", ax=_ax)

# Bottom row: the marginal CDF minus the cumulative integral of the density.
growth_cdf = cdf_g - p_g.cumulative_integrate("growth")
for _sex_idx, _ax in enumerate(axs[2]):
    growth_cdf.isel(sex=_sex_idx).plot.imshow(x="age", ax=_ax)

# Remove the temporaries so they do not leak into later cells.
del growth_cdf, _sex_idx, _ax

In [None]:
# Density over the growth grid at the first age index, one line per sex.
p_g.isel(age=0).plot(x="growth", hue="sex")

In [None]:
# Build histogram bin edges centred on the growth grid points: every edge sits half
# a grid step either side of a point. Only the first grid spacing is used, so this
# relies on the growth grid being evenly spaced.
_p_g_x = p_g.growth.values
_dp_g_x = _p_g_x[1] - _p_g_x[0]
_p_g_bins = np.append(_p_g_x - _dp_g_x / 2, _p_g_x[-1] + _dp_g_x / 2)
# Histogram distribution for a single slice (first age, first sex), treating the
# tabulated values as densities.
_distr = stats.rv_histogram((p_g.isel(age=0, sex=0), _p_g_bins), density=True)
_distr

In [None]:
# Median of the histogram distribution, computed for every slice along "growth".
_distr_median = xr.apply_ufunc(
    lambda p: stats.rv_histogram((p, _p_g_bins), density=True).median(),
    p_g,
    input_core_dims=[["growth"]],
    vectorize=True,
)
_distr_median

In [None]:
# Median taken from the root-finding quantiles (quantile level 0.5).
distr_median = distr_qs.sel(quantile=0.5)
distr_median

In [None]:
# Quantiles of the histogram distribution at the levels in ``quantiles``, for every
# slice along "growth". ``_q0`` is re-assigned here but not otherwise used in this cell.
_q0 = 1e-4
_q = quantiles
_distr_qs = xr.apply_ufunc(
    lambda p, q: stats.rv_histogram((p, _p_g_bins), density=True).ppf(q),
    p_g,
    _q,
    input_core_dims=[["growth"], _q.dims],
    output_core_dims=[_q.dims],
    vectorize=True,
)
_distr_qs

In [None]:
# Histogram-based median, then root-finding median, each against age with one line per sex.
_distr_median.plot.line(x="age", hue="sex")
distr_median.plot.line(x="age", hue="sex")

In [None]:
# Quartile-based dispersion from the histogram quantiles: 3/4 of the interquartile
# range divided by the median, so that ``_distr_qcv / 1.5`` equals IQR / (2 * median).
_distr_qcv = (3/4) * (_distr_qs.sel(quantile=0.75) - _distr_qs.sel(quantile=0.25)) / _distr_median
_distr_qcv

In [None]:
# Same quartile-based dispersion, computed from the root-finding quantiles:
# 3/4 of the interquartile range divided by the median.
distr_qcv = (3/4) * (distr_qs.sel(quantile=0.75) - distr_qs.sel(quantile=0.25)) / distr_median
distr_qcv

In [None]:
# Histogram-based quartile dispersion against age, one line per sex.
_distr_qcv.plot.line(x="age", hue="sex")

In [None]:
# Root-finding quartile dispersion against age, one line per sex.
distr_qcv.plot.line(x="age", hue="sex")

In [None]:
# Approximate sigma from the histogram-based quartile dispersion: arcsinh of
# IQR / (2 * median), divided by the standard-normal 75% quantile.
_approx_sigma = np.arcsinh(_distr_qcv / 1.5) / stats.norm.ppf(0.75)
_approx_sigma

In [None]:
# Approximate sigma from the root-finding quartile dispersion: arcsinh of
# IQR / (2 * median), divided by the standard-normal 75% quantile.
approx_sigma = np.arcsinh(distr_qcv / 1.5) / stats.norm.ppf(0.75)
approx_sigma

In [None]:
# Histogram-based sigma estimate, then root-finding sigma estimate, against age per sex.
_approx_sigma.plot.line(x="age", hue="sex")
approx_sigma.plot.line(x="age", hue="sex")

In [None]:
# Mean, variance, skewness and kurtosis ("mvsk") of the histogram distribution for
# every slice along "growth"; each moment comes back as its own scalar output.
_distr_mean, _distr_var, _distr_skew, _distr_kurt = xr.apply_ufunc(
    lambda p: stats.rv_histogram((p, _p_g_bins), density=True).stats(moments="mvsk"),
    p_g,
    input_core_dims=[["growth"]],
    output_core_dims=[(), (), (), ()],
    vectorize=True,
)
# Collect the four moments in one dataset with variables "m", "v", "s", "k" for display.
xr.Dataset(dict(zip("mvsk", (_distr_mean, _distr_var, _distr_skew, _distr_kurt))))

In [None]:
# Histogram-based mean, then median, against age with one line per sex.
_distr_mean.plot.line(x="age", hue="sex")
_distr_median.plot.line(x="age", hue="sex")

In [None]:
# Histogram-based variance against age, one line per sex.
_distr_var.plot.line(x="age", hue="sex")

In [None]:
# Standard deviation (square root of the variance) against age, one line per sex.
np.sqrt(_distr_var).plot.line(x="age", hue="sex")

In [None]:
# Standard deviation relative to the median, against age, one line per sex.
(np.sqrt(_distr_var) / _distr_median).plot.line(x="age", hue="sex")

In [None]:
# Coefficient of variation (standard deviation over mean), against age, one line per sex.
(np.sqrt(_distr_var) / _distr_mean).plot.line(x="age", hue="sex")

In [None]:
# Compare the moment-based coefficient of variation with the quartile-based dispersion.
# (np.sqrt(_distr_var) / _distr_median).plot.line(x="age", hue="sex")
(np.sqrt(_distr_var) / _distr_mean).plot.line(x="age", hue="sex")
_distr_qcv.plot.line(x="age", hue="sex")

In [None]:
# Compare the quartile-based sigma estimate with the moment-based coefficient of variation.
approx_sigma.plot.line(x="age", hue="sex")
(np.sqrt(_distr_var) / _distr_mean).plot.line(x="age", hue="sex")

In [None]:
# Histogram-based skewness against age, one line per sex.
_distr_skew.plot.line(x="age", hue="sex")

In [None]:
# Histogram-based kurtosis against age, one line per sex.
_distr_kurt.plot.line(x="age", hue="sex")

In [None]:
# Density over the growth grid at age index 0, one line per sex.
p_g.isel(age=0).plot.line(x="growth", hue="sex")

In [None]:
# Density over the growth grid at age index 57, one line per sex.
p_g.isel(age=57).plot.line(x="growth", hue="sex")

In [None]:
# Number of random draws per (non-growth) slice of ``p_g``.
N_SAMPLE = 100
# Fixed seed so that the samples, and every plot and statistic derived from them,
# are identical after a kernel restart.
SAMPLE_SEED = 0
# One shared generator: successive slices continue the same random stream, so
# they receive different draws while the overall result stays deterministic.
_sample_rng = np.random.default_rng(SAMPLE_SEED)
_distr_sample = xr.apply_ufunc(
    lambda p: stats.rv_histogram((p, _p_g_bins), density=True).rvs(
        size=N_SAMPLE, random_state=_sample_rng
    ),
    p_g,
    input_core_dims=[["growth"]],
    output_core_dims=[["sample"]],
    vectorize=True,
)
_distr_sample

In [None]:
# Top row: raw samples against age, one panel per sex index.
fig, axs = plt.subplots(2, 2, layout="constrained", figsize=(14, 12))
_distr_sample.isel(sex=0).plot.scatter(x="age", ax=axs[0, 0])
_distr_sample.isel(sex=1).plot.scatter(x="age", ax=axs[0, 1])

# Bottom row: samples divided by the median and raised to a fixed power.
# NOTE(review): the exponent -1.511e-01 is presumably a previously fitted
# power-transform parameter -- its provenance is not shown here; confirm.
_trans_distr_sample = (_distr_sample / _distr_median) **(-1.511e-01)
_trans_distr_sample.isel(sex=0).plot.scatter(x="age", ax=axs[1, 0])
_trans_distr_sample.isel(sex=1).plot.scatter(x="age", ax=axs[1, 1])

In [None]:
# Summary statistics of the transformed samples, reduced over the "sample" dimension.
xr.Dataset({
    "median": _trans_distr_sample.median(dim="sample"),
    "std": _trans_distr_sample.std(dim="sample"),
    "skew": _trans_distr_sample.reduce(stats.skew, dim="sample"),
})

In [None]:
import scipy.optimize as optimize

In [None]:
# Fit parametric densities to the numerically derived growth density of one sex.
_sex = "Male"

_m = _distr_median.sel(sex=_sex)
_s = np.sqrt(_distr_var).sel(sex=_sex)
# _cv = _s / _m
_cv = approx_sigma.sel(sex=_sex)
_truth = p_g.sel(sex=_sex)

# Cost of the (parameter-free) normal approximation, using the same definition as
# ``least_squares`` reports: 0.5 * sum(residual ** 2).
_norm_cost = (
    0.5 * ((_truth - xr.apply_ufunc(stats.norm.pdf, p_g.growth, _m, _s)) ** 2).sum()
)
print("Normal Fit:")
print("        cost:", _norm_cost.values)
# ``least_squares`` expects the raw residual vector and squares it internally.
# Returning already-squared residuals would minimise the sum of fourth powers
# and make the reported cost incomparable with ``_norm_cost``.
_bccg_fit = optimize.least_squares(
    lambda x: (
        _truth - xr.apply_ufunc(BCCG.pdf, p_g.growth, _m, _cv, *x)
    ).values.flatten(),
    x0=-0.3,
)
print("BCCG Fit:")
print(_bccg_fit)
_bcpe_fit = optimize.least_squares(
    lambda x: (
        _truth - xr.apply_ufunc(BCPE.pdf, p_g.growth, _m, _cv, *x)
    ).values.flatten(),
    x0=[-0.3, 1],
)
print("BCPE Fit:")
print(_bcpe_fit)

In [None]:
# Squared error of each fitted density, summed over the growth grid.
# Top panel: normal approximation; bottom panel: BCCG and BCPE fits.
fig, axs = plt.subplots(2, 1, figsize=(12, 12))

((_truth - xr.apply_ufunc(stats.norm.pdf, p_g.growth, _m, _s)) ** 2).sum(
    dim="growth"
).plot.line(ax=axs[0], label="Norm cost")
# Attach the legend to the axes just drawn on; a bare ``ax`` here would refer to
# whatever axes object an earlier cell happened to leave behind.
axs[0].legend()
((_truth - xr.apply_ufunc(BCCG.pdf, p_g.growth, _m, _cv, *_bccg_fit.x)) ** 2).sum(
    dim="growth"
).plot.line(ax=axs[1], label="BCCG cost")
((_truth - xr.apply_ufunc(BCPE.pdf, p_g.growth, _m, _cv, *_bcpe_fit.x)) ** 2).sum(
    dim="growth"
).plot.line(ax=axs[1], label="BCPE cost")
axs[1].legend()

In [None]:
# Compare the derived density with each fitted density at one age index.
fig, axs = plt.subplots(3, 1, figsize=(18, 12), sharex=True, layout="constrained")
_idx = 57

# (panel title, line label, density function, distribution parameters) per panel.
_fits = (
    ("Normal Fit", "norm fit", stats.norm.pdf, (_m, _s)),
    ("BCCG Fit", "bccg fit", BCCG.pdf, (_m, _cv, *_bccg_fit.x)),
    ("BCPE Fit", "bcpe fit", BCPE.pdf, (_m, _cv, *_bcpe_fit.x)),
)
for ax, (_title, _label, _pdf, _params) in zip(axs, _fits):
    _truth.isel(age=_idx).plot.line(ax=ax, x="growth", label="truth")
    xr.apply_ufunc(_pdf, p_g.growth, *_params).isel(age=_idx).plot.line(
        ax=ax, x="growth", ls="--", label=_label
    )
    ax.set_title(_title)
    # The lines carry labels; without a legend they are never shown.
    ax.legend()

# Limit the shared x axis to the outermost histogram quantiles at this age,
# passing plain floats rather than 0-d DataArrays to matplotlib.
_q_lims = _distr_qs.sel(sex=_sex).isel(age=_idx)
ax.set_xlim(_q_lims.isel(quantile=0).item(), _q_lims.isel(quantile=-1).item())

In [None]:
# Repeat the fits in log-density space for one sex.
_sex = "Male"

_m = _distr_median.sel(sex=_sex)
_s = np.sqrt(_distr_var).sel(sex=_sex)
# _cv = _s / _m
_cv = approx_sigma.sel(sex=_sex)
# NOTE(review): np.log gives -inf wherever the density is exactly zero, which
# would make the residuals non-finite -- confirm the growth grid avoids this.
_truth = np.log(p_g.sel(sex=_sex))

# Cost of the (parameter-free) normal approximation, using the same definition as
# ``least_squares`` reports: 0.5 * sum(residual ** 2).
_norm_cost = (
    0.5 * ((_truth - xr.apply_ufunc(stats.norm.logpdf, p_g.growth, _m, _s)) ** 2).sum()
)
print("Normal Fit:")
print("        cost:", _norm_cost.values)
# ``least_squares`` expects the raw residual vector and squares it internally.
# Returning already-squared residuals would minimise the sum of fourth powers
# and make the reported cost incomparable with ``_norm_cost``.
_bccg_fit = optimize.least_squares(
    lambda x: (
        _truth - xr.apply_ufunc(BCCG.logpdf, p_g.growth, _m, _cv, *x)
    ).values.flatten(),
    x0=-0.3,
)
print("BCCG Fit:")
print(_bccg_fit)
_bcpe_fit = optimize.least_squares(
    lambda x: (
        _truth - xr.apply_ufunc(BCPE.logpdf, p_g.growth, _m, _cv, *x)
    ).values.flatten(),
    x0=[-0.3, 1],
)
print("BCPE Fit:")
print(_bcpe_fit)

# Restore the plain (non-log) density for the plotting cells that follow.
_truth = p_g.sel(sex=_sex)

In [None]:
# Squared error of each fitted density, summed over the growth grid.
# Top panel: normal approximation; bottom panel: BCCG and BCPE fits.
fig, axs = plt.subplots(2, 1, figsize=(12, 12))

((_truth - xr.apply_ufunc(stats.norm.pdf, p_g.growth, _m, _s)) ** 2).sum(
    dim="growth"
).plot.line(ax=axs[0], label="Norm cost")
# Attach the legend to the axes just drawn on; a bare ``ax`` here would refer to
# whatever axes object an earlier cell happened to leave behind.
axs[0].legend()
((_truth - xr.apply_ufunc(BCCG.pdf, p_g.growth, _m, _cv, *_bccg_fit.x)) ** 2).sum(
    dim="growth"
).plot.line(ax=axs[1], label="BCCG cost")
((_truth - xr.apply_ufunc(BCPE.pdf, p_g.growth, _m, _cv, *_bcpe_fit.x)) ** 2).sum(
    dim="growth"
).plot.line(ax=axs[1], label="BCPE cost")
axs[1].legend()

In [None]:
# Compare the derived density with each fitted density at one age index.
fig, axs = plt.subplots(3, 1, figsize=(18, 12), sharex=True, layout="constrained")
_idx = 100

# (panel title, line label, density function, distribution parameters) per panel.
_fits = (
    ("Normal Fit", "norm fit", stats.norm.pdf, (_m, _s)),
    ("BCCG Fit", "bccg fit", BCCG.pdf, (_m, _cv, *_bccg_fit.x)),
    ("BCPE Fit", "bcpe fit", BCPE.pdf, (_m, _cv, *_bcpe_fit.x)),
)
for ax, (_title, _label, _pdf, _params) in zip(axs, _fits):
    _truth.isel(age=_idx).plot.line(ax=ax, x="growth", label="truth")
    xr.apply_ufunc(_pdf, p_g.growth, *_params).isel(age=_idx).plot.line(
        ax=ax, x="growth", ls="--", label=_label
    )
    ax.set_title(_title)
    # The lines carry labels; without a legend they are never shown.
    ax.legend()

# Limit the shared x axis to the outermost histogram quantiles at this age,
# passing plain floats rather than 0-d DataArrays to matplotlib.
_q_lims = _distr_qs.sel(sex=_sex).isel(age=_idx)
ax.set_xlim(_q_lims.isel(quantile=0).item(), _q_lims.isel(quantile=-1).item())