In [None]:
import numpy as np
import xarray as xr
import scipy.stats as stats
import matplotlib.pyplot as plt

In [None]:
from growthstandards import rv_to_ds, GrowthStandards

In [None]:
def coord_da(vs, name, attrs=None):
    """Build a 1-D DataArray whose single dimension is also its own coordinate.

    Parameters
    ----------
    vs : array-like
        1-D values; they become both the data and the ``name`` coordinate.
    name : str
        Name used for the dimension and for the coordinate.
    attrs : mapping, optional
        Attributes (e.g. units / long_name) attached to the returned DataArray.
        They are set before the coordinate is derived from the array.

    Returns
    -------
    xarray.DataArray
    """
    da = xr.DataArray(vs, dims=name, attrs=dict(attrs) if attrs else None)
    # assign_coords calls the callable with `da` itself, so the coordinate carries
    # exactly the same values as the data.
    return da.assign_coords({name: lambda d: d})

In [None]:
# Plot each standard's pdf as an image (variable value on the y-axis), one row per
# variable and one column per sex.
_vs = "weight", "len_hei", "bmi", "wfl", "wfh"

fig, axs = plt.subplots(len(_vs), 2, layout="constrained", figsize=(12, 5 * len(_vs)))

for raxs, v in zip(axs, _vs):
    grv = GrowthStandards[v]
    # Evaluation grid from the smallest 1st to the largest 99th percentile over all
    # batch entries. Built inline so the rv's attrs (if any) are attached before the
    # grid becomes its own coordinate.
    _y = xr.DataArray(
        np.linspace(grv.ppf(0.01).min(), grv.ppf(0.99).max(), num=1000),
        dims=v,
        attrs=dict(getattr(grv, "attrs", None) or {}),
    ).assign_coords({v: lambda da: da})
    p2d = grv.pdf(_y, apply_kwargs=dict(keep_attrs="drop_conflicts"))
    for ax, s in zip(raxs, ("Female", "Male")):
        # drop_vars is the non-deprecated replacement for DataArray.drop.
        p2d.sel(sex=s).drop_vars("sex_enum").plot.imshow(y=v, add_colorbar=False, ax=ax)
        ax.set_title(f"{v}: {s}")

### TODO: derive $p(G \mid A)$ from $p(W \mid L)$, $p(W \mid A)$ and $p(L \mid A)$, where $G = W / L$ and $A$ is age
$$
\begin{align}
A &= \text{age} \\
G &= W / L \\
p_G(g) &= \int_L |l| p_{W,L}(g l, l) \mathrm{d}l \\
&= \int_L |l| p_{W | L}(g l | l) p_L(l) \mathrm{d}l \\
&= \mathrm{E}_L[|L| p_{W | L}(g L | L)] \\
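\mathrm{P}(G=g) &= \int_L |l| \mathrm{P}(W = g l | L = l) \mathrm{P}(L = l) \mathrm{d}l \\
p_{G \mid A}(g \mid a) &= \int_L |l| \, p_{W \mid L}(g l \mid l) \, p_{L \mid A}(l \mid a) \, \mathrm{d}l \quad \text{(assuming } W \perp A \mid L \text{)}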
\end{align}
$$

In [None]:
import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap
from jax.random import PRNGKey

In [None]:
from tensorflow_probability.substrates import jax as tfp

# Short aliases for the TFP (JAX substrate) submodules used throughout the notebook.
tfd, tfb, tfpk = tfp.distributions, tfp.bijectors, tfp.math.psd_kernels

In [None]:
# Convert the length, height, weight-for-length and weight-for-height standards into
# datasets and show them. To restrict to one sex, append `.sel(sex=_sex)` to each
# dataset (e.g. with `_sex = "Male"`).
_len_ds, _hei_ds, _wfl_ds, _wfh_ds = (
    rv_to_ds(GrowthStandards[_key]) for _key in ("length", "height", "wfl", "wfh")
)
display(_len_ds, _hei_ds, _wfl_ds, _wfh_ds)

In [None]:
import jax
import optax

In [None]:
from growthstandards.bcs_ext.tfp_jax_ext import BoxCoxColeGreen, BoxCoxPowerExponential, boxcox, inv_boxcox

In [None]:
def bccg_from_ds(ds):
    """Build a BoxCoxColeGreen distribution from ``ds``'s ``mu``, ``sigma`` and ``nu``.

    Each parameter variable's values are cast to float32 before being passed, in
    that order, to ``BoxCoxColeGreen``.
    """
    mu, sigma, nu = (ds[key].values.astype(np.float32) for key in ("mu", "sigma", "nu"))
    return BoxCoxColeGreen(mu, sigma, nu)

In [None]:
# Length-for-age standard as a BoxCoxColeGreen distribution.
# Batched over sex, age
L = bccg_from_ds(_len_ds)
L

In [None]:
# Compare `loc` with the numerical median (for BCCG, mu is expected to be the
# median — TODO confirm against the tfp_jax_ext implementation).
L.loc, L.quantile(0.5)

In [None]:
# Recover sigma from the interquartile range: asinh(IQR / (2 * loc)) / z_0.75.
# This is exact when nu == 0 (then the quartiles are loc * exp(-/+ sigma * z_0.75));
# the allclose shows how well it matches the stored scale otherwise.
print(L.scale)
_L_approx_scale = jnp.arcsinh(
    ((L.quantile(0.75) - L.quantile(0.25)) / L.loc) / 2
) / stats.norm.ppf(0.75)
print(_L_approx_scale)
np.allclose(L.scale, _L_approx_scale, rtol=1e-3)

In [None]:
# Map 100 samples per (sex, age) to z-scores via z = boxcox(y / loc, nu) / scale
# (assumes `boxcox(x, nu)` is the Box-Cox transform (x**nu - 1) / nu — TODO confirm).
_z = xr.DataArray(
    boxcox(L.sample(100, seed=PRNGKey(0)) / L.loc, L.nu) / L.scale,
    dims=("sample", "sex", "age"),
    coords={"sex": _len_ds["sex"], "age": _len_ds["age"]}
)
_z

In [None]:
# Scatter the z-scores against age, one panel per sex (shared y-axis).
fig, axs = plt.subplots(1, 2, figsize=(12, 5), layout="constrained", sharey=True)
_z.isel(sex=0).plot.scatter(ax=axs[0], x="age")
_z.isel(sex=1).plot.scatter(ax=axs[1], x="age")

In [None]:
# Build G = W / L at each age as a mixture, over the weight-for-length length grid,
# of the distributions of W / l.
_wfl_length = _wfl_ds["length"].values.astype(np.float32)
# _wfl_mu = _wfl_ds["mu"].values.astype(np.float32)
# _wfl_sigma = _wfl_ds["sigma"].values.astype(np.float32)
# _wfl_nu = _wfl_ds["nu"].values.astype(np.float32)  # [..., 0]

# Batched over sex, length
batched_WFL = bccg_from_ds(_wfl_ds)

# WFL = tfd.MixtureSameFamily(
#     tfd.Categorical(logits=L[..., None].log_prob(_wfl_length)),
#     batched_WFL,
# )

# W / l for each grid length l, two ways: a reference Scale(1 / l) transform, and a
# closed-form BCCG with loc divided by l and unchanged scale / nu. The asserts check
# that both agree in mean and variance.
_batched_G = tfd.TransformedDistribution(batched_WFL, tfb.Scale(1 / _wfl_length))
batched_G = BoxCoxColeGreen(batched_WFL.loc / _wfl_length, batched_WFL.scale, batched_WFL.nu)
assert np.allclose(batched_G.mean(), _batched_G.mean())
assert np.allclose(batched_G.variance(), _batched_G.variance())

# Mixture over the length grid (last axis). The weight of length l at each (sex, age)
# is the length-for-age log density at l; L[..., None] broadcasts it against the grid,
# and Categorical normalises the weights. batched_G[..., None, :] inserts an axis so the
# components broadcast against the age axis. This discretises
#   p(g | a) = integral over l of p_{G|L}(g | l) * p_{L|A}(l | a) dl
# NOTE(review): the grid spacing cancels in the normalisation only if the length grid
# is uniform — confirm.
G = tfd.MixtureSameFamily(
    tfd.Categorical(logits=L[..., None].log_prob(_wfl_length)),
    batched_G[..., None, :],
)
G

In [None]:
# Compare weight-for-length `loc` with the numerical median.
batched_WFL.loc, batched_WFL.quantile(0.5)

In [None]:
# Stored scale vs. the IQR-based approximation asinh(IQR / (2 * loc)) / z_0.75.
batched_WFL.scale, jnp.arcsinh(
    ((batched_WFL.quantile(0.75) - batched_WFL.quantile(0.25)) / batched_WFL.loc) / 2
    ) / stats.norm.ppf(0.75)

In [None]:
# Draw 100 weight samples per (sex, length) from the weight-for-length standard.
_sampled_wfl = xr.DataArray(
    batched_WFL.sample(100, seed=PRNGKey(0)),
    dims=("sample", "sex", "length"),
    # The third axis is length: label it with the length grid under the "length" key
    # (a coordinate named "age" would leave the "length" dimension unlabelled and
    # break the `x="length"` plots below).
    coords={"sex": _wfl_ds["sex"], "length": _wfl_ds["length"]},
)
_sampled_wfl

In [None]:
# Sample skewness along the sample axis, per (sex, length).
_sampled_wfl.reduce(stats.skew, dim="sample")

In [None]:
# Box-Cox z-scores of the weight samples (rebinds `_z`, previously the length z-scores).
_z = xr.apply_ufunc(lambda w: boxcox(w / batched_WFL.loc, batched_WFL.nu) / batched_WFL.scale, _sampled_wfl)
_z

In [None]:
# Top row: V = (W / loc) ** nu; bottom row: z-scores; one column per sex.
# The trailing Dataset summarises the per-(sex, length) mean and std of the z-scores.
fig, axs = plt.subplots(2, 2, figsize=(12, 10), layout="constrained", sharey="row")
# _sampled_wfl.isel(sex=0).plot.scatter(ax=axs[0, 0], x="length")
# _sampled_wfl.isel(sex=1).plot.scatter(ax=axs[0, 1], x="length")

_v = (_sampled_wfl / batched_WFL.loc) ** batched_WFL.nu
_v.isel(sex=0).plot.scatter(ax=axs[0, 0], x="length")
_v.isel(sex=1).plot.scatter(ax=axs[0, 1], x="length")

_z.isel(sex=0).plot.scatter(ax=axs[1, 0], x="length")
_z.isel(sex=1).plot.scatter(ax=axs[1, 1], x="length")
xr.Dataset({"mean": _z.mean(dim="sample"), "std": _z.std(dim="sample")})

In [None]:
# If z = (V - 1) / (nu * sigma) (Box-Cox with power nu — assumed), then
# std(V) = sigma * |nu|, so this ratio is expected to be close to 1.
(batched_WFL.scale * abs(batched_WFL.nu)) / _v.std(dim="sample")

In [None]:
# Sample skewness of V per (sex, length).
_v.reduce(stats.skew, dim="sample")

In [None]:
from tensorflow_probability.python.internal.backend.jax.numpy_math import divide_no_nan

In [None]:
from growthstandards.bcs_ext.tfp_jax_ext import same_family_mixture_quantile

In [None]:
# Median of the growth mixture G per (sex, age), via the project helper
# `same_family_mixture_quantile` (presumably a numerical inversion of the mixture CDF).
g_median = same_family_mixture_quantile(G, 0.5)
g_median

In [None]:
# Quartile-based dispersion: (3/4) * IQR / median. Note g_qcv / 1.5 == IQR / (2 * median),
# which is the quantity the sigma approximation below uses.
g_qcv = (3/4) * (same_family_mixture_quantile(G, 0.75) - same_family_mixture_quantile(G, 0.25)) / g_median
g_qcv

In [None]:
# Plot g_qcv along the age axis, one panel per sex.
fig, axs = plt.subplots(1, 2, figsize=(12, 5), layout="constrained", sharey=True)
axs[0].plot(g_qcv[0])
axs[1].plot(g_qcv[1])

In [None]:
# Approximate BCCG sigma of G: asinh(IQR / (2 * median)) / z_0.75 (exact for nu == 0).
g_approx_sigma = jnp.arcsinh(g_qcv / 1.5) / stats.norm.ppf(0.75)
g_approx_sigma

In [None]:
# Growth-value grid spanning G's [1e-4, 1 - 1e-4] quantile range across all
# (sex, age) entries; used for density evaluation.
_q0 = 1e-4
_lb = same_family_mixture_quantile(G, _q0).min() #.min(axis=-1)
_ub = same_family_mixture_quantile(G, 1 - _q0).max() #.max(axis=-1)
growth_da = coord_da(np.linspace(_lb, _ub, 1_000), "growth")

In [None]:
# Mixture density of G on `growth_da`. With vectorize=True, apply_ufunc calls G.prob
# once per growth value, and each call returns a (sex, age) array.
p_g = xr.apply_ufunc(
    G.prob, growth_da, output_core_dims=[["sex", "age"]], vectorize=True
).assign_coords({"sex": _len_ds["sex"], "age": _len_ds["age"]})
p_g.shape

In [None]:
# Inspect the mixture components (the distributions of W / l).
G.components_distribution.scale

In [None]:
G.components_distribution.variance()

In [None]:
G.components_distribution.scale.sum(axis=-1)

In [None]:
# (Same expression as two cells above.)
G.components_distribution.scale

In [None]:
G.components_distribution.nu

In [None]:
g_approx_sigma

In [None]:
# Components with loc multiplied by their mixture weight w_l. Scaling a BCCG variable
# by a constant only rescales loc (as the batched_G asserts check for 1 / length), so
# component l is the distribution of w_l * G_l.
_G_part_mix_prob = G.mixture_distribution.probs_parameter()
_G_parts = BoxCoxColeGreen(
    G.components_distribution.loc * _G_part_mix_prob,
    G.components_distribution.scale,
    G.components_distribution.nu,
)
_G_parts

In [None]:
G.mean()

In [None]:
# Heuristic: sum_l quantile(w_l * G_l) is the quantile of the comonotonic sum
# sum_l w_l * G_l, not of the mixture; these cells check how close it gets.
g_median, _G_parts.quantile(0.5).sum(axis=-1)

In [None]:
assert np.allclose(g_median, _G_parts.quantile(0.5).sum(axis=-1), rtol=1e-2)

In [None]:
g_75 = same_family_mixture_quantile(G, 0.75)
g_75

In [None]:
_G_parts.quantile(0.75).sum(axis=-1)

In [None]:
np.allclose(g_75, _G_parts.quantile(0.75).sum(axis=-1), rtol=5e-2)

In [None]:
np.allclose(g_75, _G_parts.quantile(0.75).sum(axis=-1), rtol=1e-2)

In [None]:
g_qcv

In [None]:
# Two aggregations of the (3/4) * IQR / loc measure over the weighted parts: this cell
# takes the ratio per component and then sums (divide_no_nan yields 0 where loc is 0);
# the next cell sums quantiles and locs first and then takes the ratio.
# _g_qcv = (3/4) * (_G_parts.quantile(0.75) - _G_parts.quantile(0.25)).sum(axis=-1) / _G_parts.loc.sum(axis=-1)
_g_qcv = (3/4) * divide_no_nan((_G_parts.quantile(0.75) - _G_parts.quantile(0.25)), _G_parts.loc)
_g_qcv.sum(axis=-1)

In [None]:
_g_qcv = (3/4) * (_G_parts.quantile(0.75).sum(axis=-1) - _G_parts.quantile(0.25).sum(axis=-1)) / _G_parts.loc.sum(axis=-1)
_g_qcv

In [None]:
g_approx_sigma

In [None]:
# NOTE(review): 731 is a hard-coded count, presumably meant to turn the sum into a mean.
# Confirm what it counts and derive it from an array shape instead.
(_G_parts.scale).sum(axis=-1) / 731

In [None]:
# Mixture-weighted average of the component scales.
(_G_parts.scale * G.mixture_distribution.probs_parameter()).sum(axis=-1)

In [None]:
# Single-BCCG approximation of G: loc = sum_l w_l * loc_l, scale = the IQR-based
# g_approx_sigma, and nu taken from the first component.
_G = BoxCoxColeGreen(
    _G_parts.loc.sum(axis=-1),
    # jnp.sqrt(_G_parts.variance().sum(axis=-1)) / _G_parts.loc.sum(axis=-1),
    g_approx_sigma,
    # all values equal on last axis
    _G_parts.nu[..., 0],
)
_G

In [None]:
# Compare quantiles, means and variances of the approximation _G with the mixture G.
# np.allclose(same_family_mixture_quantile(G, 0.5), _G.quantile(0.5), rtol=1e-2)
np.allclose(g_median, _G.quantile(0.5), rtol=1e-2)


In [None]:
np.allclose(same_family_mixture_quantile(G, 0.75), _G.quantile(0.75), rtol=1e-2)

In [None]:
(G.mean(), _G.mean())

In [None]:
(G.variance(), _G.variance())

### Fitting

In [None]:
Root = tfd.JointDistributionCoroutine.Root


@tfd.JointDistributionCoroutine
def test_model():
    """Toy prior model: per-sex sigma and nu feeding a BCCG over growth at g_median."""
    # Inverse-Gaussian prior on sigma, moment-matched per sex to g_approx_sigma across
    # ages: for IG(mean m, concentration lam), Var = m**3 / lam, so lam = m**3 / var.
    scale = yield Root(tfd.InverseGaussian(
        jnp.mean(g_approx_sigma, axis=-1),
        (jnp.mean(g_approx_sigma, axis=-1)**3) / jnp.var(g_approx_sigma, axis=-1),
        name="sigma"
    )[:, None])
    # nu has no parent variables, so it is a root too; root distributions must be
    # wrapped in Root for sample shapes to be applied correctly.
    nu = yield Root(tfd.Normal(
        jnp.array([-0.2, -0.2]),
        jnp.array([0.1, 0.1]),
        name="lmbda"
    )[:, None])
    yield BoxCoxColeGreen(
        g_median,
        scale,
        nu,
        name="growth"
    )
test_model

In [None]:
# Split the PRNG key: `init_seed` for parameter initialisation, `seed` for fitting.
seed = PRNGKey(0)
init_seed, seed = jax.random.split(seed)

In [None]:
# Affine surrogate (lower-triangular operator) built from `test_model`.
# NOTE(review): `test_model` is passed as the base distribution to be transformed;
# this looks exploratory — init_fn / build_fn are reassigned in the fitting cell below.
init_fn, build_fn = tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution_stateless(
    test_model,
    operators="tril",
)
initial_parameters = init_fn(seed=init_seed)
initial_parameters

In [None]:
g_approx_sigma.mean(axis=-1)

In [None]:
# Full (batch + event) shape of G; the fitting cell uses it as the surrogate's shape.
G.batch_shape + G.event_shape

In [None]:
# Only referenced by commented-out alternatives in the fit below.
independent_G = tfd.Independent(G, 1)

# Trainable BCCG surrogate with the same batch + event shape as G. Only `scale` is
# trainable (initialised to g_approx_sigma); loc (= g_median) and nu are passed as
# fixed parameters.
init_fn, build_fn = tfp.experimental.util.make_trainable_stateless(
    # tfd.Normal,
    # tfd.MultivariateNormalDiag,
    BoxCoxColeGreen,
    # BoxCoxPowerExponential,
    # initial_parameters=dict(loc=g_median),
    # initial_parameters=dict(scale=g_scale),
    initial_parameters=dict(scale=g_approx_sigma),
    # initial_parameters=dict(loc=g_median, scale=g_scale),
    # initial_parameters=dict(loc=g_median, nu=jnp.ones_like(g_median)),
    batch_and_event_shape=(G.batch_shape + G.event_shape),
    name="q_z",
    # fixed params
    loc=g_median,
    # scale=g_scale,
    # scale=g_approx_sigma,
    nu=_G_parts.nu[..., 0],
)
initial_parameters = init_fn(seed=init_seed)

# if "scale" in initial_parameters._fields:
#     initial_parameters = initial_parameters._replace(
#         # scale=initial_parameters.scale[:, :1],
#         scale=g_scale.mean(axis=-1)[..., None],
#     )

# Override the initial nu / power only when they are trainable parameters (presumably
# inactive with the configuration above, where nu is fixed).
# NOTE(review): the nu values look like results of an earlier fit — record their provenance.
if "nu" in initial_parameters._fields:
    initial_parameters = initial_parameters._replace(
        # nu=initial_parameters.nu[:, :1],
        nu=jnp.array([-0.17488582, -0.20942171])[..., None],
    )
if "power" in initial_parameters._fields:
    initial_parameters = initial_parameters._replace(
        # power=initial_parameters.power[:, :1],
        power=jnp.array([2.0, 2.0])[..., None],
    )


# _build_fn = build_fn
# def build_fn(*params):
#     distr = _build_fn(*params)
#     return tfd.Independent(distr, 1)


# Debugging helper: prints its inputs and the built surrogate, then scores the last
# parameter under G. It is not referenced by the fit below.
def mixed_log_prob(*params):
    print(params, [p.shape for p in params])
    distr = build_fn(*params)
    print(distr)
    *params, growth = params
    return G.unnormalized_log_prob(growth)

# Fit the surrogate against G's unnormalised log density with Adam (lr 0.01) for 400
# steps, 10 samples per step and doubly-reparameterised gradients; all traceable
# quantities are kept in result_traces.
optimized_parameters, result_traces = tfp.vi.fit_surrogate_posterior_stateless(
    G.unnormalized_log_prob,
    # independent_G.unnormalized_log_prob,
    # independent_G.log_prob,
    # lambda *params: print(params, [p.shape for p in params]) or G.unnormalized_log_prob(params[-1]),
    build_surrogate_posterior_fn=build_fn,
    initial_parameters=initial_parameters,
    optimizer=optax.adam(learning_rate=0.01),
    # num_steps=1_000,
    # num_steps=500,
    num_steps=400,
    # num_steps=100,
    # num_steps=10,
    # num_steps=1,
    sample_size=10,
    # sample_size=1,
    # jit_compile=True,
    # trace_fn=lambda traceable_quantities: traceable_quantities.loss,
    trace_fn=lambda traceable_quantities: traceable_quantities,
    seed=seed,

    # gradient_estimator=tfp.vi.GradientEstimators.SCORE_FUNCTION,
    gradient_estimator=tfp.vi.GradientEstimators.DOUBLY_REPARAMETERIZED,
)
# NOTE: `losses` holds the full traces, not only the loss values (see result_traces.loss).
losses = result_traces
q_z = build_fn(*optimized_parameters)
print(q_z)
q_z.parameters

In [None]:
# Largest final-step loss across the batch.
result_traces.loss[-1].max()

In [None]:
result_traces.loss[-1]

In [None]:
# Indices of NaN losses, if any.
np.where(np.isnan(result_traces.loss))

In [None]:
result_traces._asdict()

In [None]:
# Rebuild the surrogate at its initial (pre-optimisation) parameters for inspection.
# The parameter namedtuple is unpacked, matching `build_fn(*optimized_parameters)`.
_init_d = build_fn(*initial_parameters)
display(_init_d.parameters)
_init_d.sample([], seed=seed)

In [None]:
# Fitted surrogate per sex:
#   row 0: loc;
#   row 1: fitted scale (dots) vs. g_approx_sigma (dashed);
#   row 2: their absolute difference in percent of the fitted scale.
fig, axs = plt.subplots(3, 2, figsize=(12, 15), layout="constrained", sharey="row")
axs[0, 0].plot(q_z.loc[0])
axs[0, 1].plot(q_z.loc[1])
axs[1, 0].plot(q_z.scale[0], marker=".", ls="")
axs[1, 1].plot(q_z.scale[1], marker=".", ls="")
axs[1, 0].plot(g_approx_sigma[0], 'k--', linewidth=3)
axs[1, 1].plot(g_approx_sigma[1], 'k--', linewidth=3)
axs[2, 0].plot(100 * abs(q_z.scale[0] - g_approx_sigma[0]) / q_z.scale[0])
axs[2, 1].plot(100 * abs(q_z.scale[1] - g_approx_sigma[1]) / q_z.scale[1])
print(q_z.nu)
if hasattr(q_z, "power"):
    print(q_z.power)

In [None]:
q_z.batch_shape

In [None]:
# Surrogate density on the same growth grid as p_g, for comparison.
p_g_2 = xr.apply_ufunc(
    q_z.prob, growth_da, output_core_dims=[["sex", "age"]], vectorize=True
).assign_coords({"sex": _len_ds["sex"], "age": _len_ds["age"]})
p_g_2.shape

In [None]:
# Rows: mixture density p_g, surrogate density p_g_2, and their absolute difference.
# NOTE(review): the difference images use a fixed [0, 1] colour range; differences
# above 1 saturate — compare with d_p_g.max() in the next cell.
fig, axs = plt.subplots(3, 2, layout="constrained", figsize=(12, 15))

p_g.isel(sex=0).plot(ax=axs[0, 0], x="age")
p_g.isel(sex=1).plot(ax=axs[0, 1], x="age")
p_g_2.isel(sex=0).plot(ax=axs[1, 0], x="age")
p_g_2.isel(sex=1).plot(ax=axs[1, 1], x="age")
d_p_g = abs(p_g - p_g_2)
# d_p_g = xr.apply_ufunc(divide_no_nan, d_p_g, p_g)
d_p_g.isel(sex=0).plot.imshow(ax=axs[2, 0], x="age", vmin=0, vmax=1)
d_p_g.isel(sex=1).plot.imshow(ax=axs[2, 1], x="age", vmin=0, vmax=1)

In [None]:
d_p_g.max(), d_p_g.min()

In [None]:
# Deliberate stop for "Run All": execution does not proceed past this cell.
raise RuntimeError("Intentional stop; run the cells below this point manually.")

In [None]:
# Tuple so both values are shown (Jupyter only displays a cell's last expression).
G.batch_shape, G.dtype

In [None]:
Root = tfd.JointDistributionCoroutine.Root


# Experimental model: fit_g ~ Normal(mu, sigma), plus a final node "zero" distributed
# as G shifted by -fit_g. Pinning zero = 0.0 below therefore adds log p_G(fit_g) to
# the joint log density.
@tfd.JointDistributionCoroutine
def _deterministic_fit_model():
    _z = jnp.zeros(G.batch_shape, dtype=G.dtype)
    # g = yield Root(G)
    mu = yield Root(tfd.Normal(_z, 1.0, name="mu"))
    sigma = yield Root(tfd.HalfCauchy(_z, 5.0, name="sigma"))
    fit_g = yield tfd.Normal(mu, sigma, name="fit_g")
    # yield tfd.Deterministic(g - fit_g, name="zero")
    yield tfd.TransformedDistribution(G, tfb.Shift(-fit_g), name="zero")


deterministic_fit_model = _deterministic_fit_model.experimental_pin(zero=0.0)
deterministic_fit_model

In [None]:
G.components_distribution.quantile(0.5)

In [None]:
G.mixture_distribution.probs_parameter().shape

In [None]:
Root = tfd.JointDistributionCoroutine.Root

# Same computation as the earlier g_median cell.
g_median = same_family_mixture_quantile(G, 0.5)


# Priors for the MCMC fit: mu ~ Normal(g_median, 1), sigma ~ HalfCauchy(0, 5), and
# the latent fit_g ~ Normal(mu, sigma). G's density at fit_g is added separately,
# in the MCMC cell's target_log_prob_fn, rather than here.
@tfd.JointDistributionCoroutine
def _fit_model():
    _z = jnp.zeros(G.batch_shape, dtype=G.dtype)
    # g = yield Root(G)
    mu = yield Root(tfd.Normal(g_median, 1.0, name="mu"))
    sigma = yield Root(tfd.HalfCauchy(_z, 5.0, name="sigma"))
    fit_g = yield tfd.Normal(mu, sigma, name="fit_g")


fit_model = _fit_model  # .experimental_pin(zero=0.0)
fit_model

In [None]:
import jax

In [None]:
# NUTS with dual-averaging step-size adaptation, targeting fit_model's joint density
# plus G's log density at fit_g.
step_size = 0.1
num_steps = 500
burnin = 50

# Initial state: a prior sample mapped to the unconstrained space and flattened into a
# list of arrays (treedef restores the model structure).
event_space_bijector = fit_model.experimental_default_event_space_bijector()
init_state = event_space_bijector.inverse(fit_model.sample(seed=PRNGKey(0)))
init_state, treedef = jax.tree_util.tree_flatten(init_state)


def target_log_prob_fn(*x):
    """Target log density on the unconstrained space.

    ``x`` holds the flat unconstrained state parts. They are restructured, mapped to
    the constrained space, and scored as fit_model's joint log density plus G's log
    density at ``fit_g``. The chain moves in the unconstrained space, and the
    transition kernel below only applies identity bijectors. So the event-space
    bijector's log-det-Jacobian is added here, giving the change-of-variables density
    for the unconstrained parameters.
    """
    x = jax.tree_util.tree_unflatten(treedef, x)
    y = event_space_bijector.forward(x)
    p_y = fit_model.log_prob(y)
    p_g = G.log_prob(y.fit_g)
    log_det_jacobian = event_space_bijector.forward_log_det_jacobian(x)
    return p_y + p_g + log_det_jacobian


def trace_fn(_, pkr):
    # Result nesting: pkr (step-size adaptation) -> .inner_results (transformed
    # kernel) -> .inner_results (NUTS), which holds the per-step diagnostics.
    return (
        pkr.inner_results.inner_results.target_log_prob,
        pkr.inner_results.inner_results.leapfrogs_taken,
        pkr.inner_results.inner_results.has_divergence,
        pkr.inner_results.inner_results.energy,
        pkr.inner_results.inner_results.log_accept_ratio,
    )


# One identity bijector per state part (mu, sigma, fit_g): the state is already
# unconstrained, and target_log_prob_fn handles the transformation itself.
unconstraining_bijectors = [
    tfb.Identity(),
    tfb.Identity(),
    tfb.Identity(),
]

kernel = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size)
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=kernel, bijector=unconstraining_bijectors
)

hmc = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=burnin,
    step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
        inner_results=pkr.inner_results._replace(step_size=new_step_size)
    ),
    step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
    log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
)
chain_state, sampler_stat = tfp.mcmc.sample_chain(
    num_results=num_steps,
    num_burnin_steps=burnin,
    current_state=init_state,
    kernel=hmc,
    trace_fn=trace_fn,
    seed=PRNGKey(0),
)
chain_state, sampler_stat

In [None]:
init_state

In [None]:
# (Duplicate of the previous cell.)
init_state

In [None]:
# Displays the sample_chain function object (interactive inspection).
tfp.mcmc.sample_chain

In [None]:
fit_model.sample(seed=PRNGKey(0))

In [None]:
# 10 samples from the pinned model together with their log weights.
deterministic_fit_model.sample_and_log_weight([10], seed=PRNGKey(0))

In [None]:
batched_G.quantile(0.5)

In [None]:
# Shape bookkeeping for G: mixture batch shape and number of categories (via the
# private `_num_categories`) vs. the component batch shape.
G.batch_shape, (
    G.mixture_distribution.batch_shape,
    G.mixture_distribution._num_categories(),
), G.components_distribution.batch_shape

In [None]:
G.sample(seed=PRNGKey(0))

In [None]:
# g = xr.DataArray(np.linspace(0.05, 0.2, 1_000), dims="growth").assign_coords(growth=lambda da: da)
g = xr.DataArray(np.linspace(0.05, 0.2, 100), dims="growth").assign_coords(
    growth=lambda da: da
)
# p_g = G.prob(np.broadcast_to(g, [*G.batch_shape, len(g)]))
p_g = xr.apply_ufunc(G.prob, g, output_core_dims=[["sex", "age"]], vectorize=True)
p_g = p_g.assign_coords({"sex": _len_ds["sex"], "age": _len_ds["age"]})
p_g.shape

In [None]:
fig, axs = plt.subplots(1, 2, layout="constrained", figsize=(12, 5))

p_g.isel(sex=0).plot(ax=axs[0], x="age")
p_g.isel(sex=1).plot(ax=axs[1], x="age")

In [None]:
# Experimental generative model: pick a length-grid index from the length-for-age
# weights, then draw weight from the weight-for-length BCCG at that length.
# (Rebinds `_wfl_length` as a jnp array.)
_wfl_length = jnp.array(_wfl_ds["length"].values.astype(np.float32))
_wfl_mu = jnp.array(_wfl_ds["mu"].values.astype(np.float32))
_wfl_sigma = jnp.array(_wfl_ds["sigma"].values.astype(np.float32))
_wfl_nu = jnp.array(_wfl_ds["nu"].values.astype(np.float32))

_logits = L[..., None].log_prob(_wfl_length)


@tfd.JointDistributionCoroutineAutoBatched
def model():
    idx = yield tfd.Categorical(logits=_logits, name="idx")
    length = _wfl_length[idx]
    args = _wfl_mu, _wfl_sigma, _wfl_nu
    # Pick each draw's BCCG parameters at its sampled length index.
    args = (jnp.take_along_axis(a, idx, axis=-1) for a in args)
    weight = yield BoxCoxColeGreen(*args, name="wfl")
    # weight = yield tfd.TransformedDistribution(
    #     tfd.Normal(0.0, 1.0),
    #     tfb.Chain([
    #         tfb.Scale(_wfl_mu[idx]),
    #         BoxCoxTransform(_wfl_nu[idx]),
    #         tfb.Scale(_wfl_sigma[idx])
    #     ]),
    #     name="wfl",
    # )
    # NOTE(review): `growth` is computed but never yielded, so it is not part of the
    # model's output; yield it (e.g. as a tfd.Deterministic) if it is meant to be.
    growth = weight / length


model