In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
from functools import partial
from tqdm import tqdm
import numpy as np

In [None]:
import jax.numpy as jnp
import numpyro as npyro
import numpyro.distributions as npyro_dist
import jax

#npyro.enable_x64(True)

def model(
    s,
    m,
    gamma_hyper=1.0,
    pi_hyper=1.0,
    rho_hyper=1.0,
    epsilon_hyper=0.01,
    alpha_hyper=1000.0,
    y=None,
):
    """Generative model for an (n, g) count matrix as a mixture of `s` strains.

    Generative process (shapes are noted in comments in the body):
      gamma[k, j] ~ Beta(gamma_hyper, gamma_hyper)      per strain k, position j
      rho         ~ Dirichlet(rho_hyper * ones(s))
      pi[i]       ~ Dirichlet(rho * s * pi_hyper)        per sample i
      alpha[i]    ~ Gamma(alpha_hyper, 1.0)              per sample i
      epsilon[i]  ~ Beta(1.0, 1 / epsilon_hyper)         per sample i
      p_noerr     = pi @ gamma
      p           = (1 - epsilon/2) * p_noerr + (epsilon/2) * (1 - p_noerr)
      y[i, j]     ~ BetaBinomial(alpha[i] * p[i, j], alpha[i] * (1 - p[i, j]),
                                 total_count=m[i, j])

    Args:
        s: Number of strains (size of the "strain" plate and of the Dirichlets).
        m: Array of total counts. Its shape ``(n, g)`` sets the number of
            samples ``n`` and positions ``g``; it is also the BetaBinomial
            ``total_count``.
        gamma_hyper: Symmetric Beta concentration for ``gamma``; values below 1
            make the prior U-shaped, pushing ``gamma`` toward 0 or 1.
        pi_hyper: Scales the Dirichlet concentration of ``pi``. Since ``rho``
            sums to 1, the concentration ``rho * s * pi_hyper`` sums to
            ``s * pi_hyper`` and the prior mean of each ``pi[i]`` is ``rho``.
        rho_hyper: Symmetric Dirichlet concentration for ``rho``.
        epsilon_hyper: ``epsilon ~ Beta(1, 1 / epsilon_hyper)``, whose mean is
            ``epsilon_hyper / (1 + epsilon_hyper)``.
        alpha_hyper: Shape of the ``Gamma(alpha_hyper, 1.0)`` prior on
            ``alpha`` (prior mean ``alpha_hyper``).
        y: Observed counts, expected to be ``(n, g)``; when None, ``y`` is
            sampled instead of conditioned on.

    Returns:
        None. Everything is recorded as numpyro sites: sampled "gamma", "rho",
        "pi", "alpha", "epsilon", "y" and deterministic "p_noerr", "p".
    """

    n, g = m.shape

    with npyro.plate("position", g, dim=-1):
        with npyro.plate("strain", s, dim=-2):
            gamma = npyro.sample("gamma", npyro_dist.Beta(gamma_hyper, gamma_hyper))
    # gamma.shape == (s, g)

    rho = npyro.sample(
        "rho",
        npyro_dist.Dirichlet(jnp.ones(s) * rho_hyper),
    )
    # rho.shape == (s,)

    with npyro.plate("sample", n, dim=-1):
        pi = npyro.sample("pi", npyro_dist.Dirichlet(rho * s * pi_hyper))
        # alpha and epsilon are reshaped into columns so they broadcast
        # row-wise against the (n, g) arrays below.
        alpha = npyro.sample("alpha", npyro_dist.Gamma(alpha_hyper, 1.0)).reshape(
            (-1, 1)
        )
        epsilon = npyro.sample(
            "epsilon", npyro_dist.Beta(1.0, 1 / epsilon_hyper)
        ).reshape((-1, 1))
    # pi.shape == (n, s)
    # alpha.shape == epsilon.shape == (n, 1)  (sampled as (n,), then reshaped)

    # Each row of pi is on the simplex and gamma is in [0, 1], so p_noerr is in [0, 1].
    p_noerr = npyro.deterministic("p_noerr", pi @ gamma)
    # Symmetric error: keep p_noerr with weight 1 - epsilon/2 and take its
    # complement with weight epsilon/2. This puts p in [epsilon/2, 1 - epsilon/2],
    # so both BetaBinomial concentrations stay positive whenever epsilon > 0.
    p = npyro.deterministic(
        "p", (1 - epsilon / 2) * (p_noerr) + (epsilon / 2) * (1 - p_noerr)
    )
    # p.shape == (n, g)

    # The likelihood is not wrapped in a plate; its (n, g) batch shape comes
    # from broadcasting alpha, p and m.
    y = npyro.sample(
        "y",
        npyro_dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
        obs=y
    )
    # y.shape == (n, g)

In [None]:
# Simulate one synthetic dataset from the prior: 10 samples x 20 positions,
# with the same total count (100) at every entry.
sim_m = jnp.full((10, 20), 100, dtype=int)

# Bind the simulation settings so `prior` can be called without arguments.
prior = partial(
    model,
    s=4,
    m=sim_m,
    gamma_hyper=1e-2,
    pi_hyper=0.1,
    rho_hyper=0.5,
)

rng0 = jax.random.PRNGKey(1)
prior_predictive = npyro.infer.Predictive(prior, num_samples=1)
sim = prior_predictive(rng0)

In [None]:
# Simulated mixture fractions: one row per sample, one column per strain.
ax = sns.heatmap(sim['pi'].squeeze())
ax.set(title="Simulated pi", xlabel="strain", ylabel="sample");

In [None]:
# Simulated per-strain values: one row per strain, one column per position.
ax = sns.heatmap(sim['gamma'].squeeze())
ax.set(title="Simulated gamma", xlabel="position", ylabel="strain");

In [None]:
# Predictive prepends a num_samples (=1) axis; index it away to show the
# (n, g) count matrix itself.
sim['y'][0]

In [None]:
# `rng0` was already consumed by the prior simulation, and JAX keys must never
# be reused (including splitting a key after sampling with it). Derive the keys
# for the fit from an independent seed and split that for later operations.
rng1, rng2 = jax.random.split(jax.random.PRNGKey(2))

In [None]:
from numpyro.infer import autoguide
from jax import lax

NUM_SVI_STEPS = 2000

# Predictive prepends a num_samples (=1) axis; drop it so the observation has
# the (n, g) shape of the model's "y" site.
sim_y = sim['y'][0]

guide = npyro.infer.autoguide.AutoDiagonalNormal(prior)
opt = npyro.optim.ClippedAdam(step_size=0.1, clip_norm=1000.)
svi = npyro.infer.SVI(prior, guide, opt, loss=npyro.infer.Trace_ELBO(), y=sim_y)

svi_init = svi.init(rng1)
svi_state = svi_init
# Compile one SVI step once. Calling the un-jitted `svi.update` in a Python
# loop dispatches every operation eagerly and is far slower.
svi_update = jax.jit(svi.update)

pbar = tqdm(range(NUM_SVI_STEPS))
# Loss returned by SVI at each step (Trace_ELBO's loss is the negative ELBO).
history = []
for _ in pbar:
    next_state, loss = svi_update(svi_state)
    loss = float(loss)

    # Check for divergence *before* accepting the step, so `svi_state` keeps
    # the last finite parameters for the cells below.
    if np.isnan(loss):
        break
    svi_state = next_state
    history.append(loss)

    if len(history) > 1:
        pbar.set_postfix({'loss': history[-1], 'delta': history[-2] - history[-1]})

In [None]:
# Point estimate of every latent site: the median of the fitted variational
# posterior. The guide needs the parameters extracted from the SVI state,
# not the state itself.
guide.median(svi.get_params(svi_state))

In [None]:
# SVI loss per step (the negative ELBO, so lower is better).
fig, ax = plt.subplots()
ax.plot(history)
ax.set(title="SVI convergence", xlabel="step", ylabel="loss (negative ELBO)");

In [None]:
# The guide's starting point: median of the variational posterior at the state
# returned by `svi.init(rng1)`, before any SVI updates.
guide.median(svi.get_params(svi_init))

In [None]:
# Fitted variational parameters. The raw `svi_state` also carries optimizer
# state and the RNG key, which are rarely useful to inspect.
svi.get_params(svi_state)

In [None]:
# Posterior uncertainty: 5%, 50% and 95% quantiles of each latent site under
# the fitted diagonal-normal guide.
guide.quantiles(svi.get_params(svi_state), [0.05, 0.5, 0.95])

In [None]:
# One draw of every latent site from the fitted guide. `sample_posterior`
# takes the variational parameters, and uses a fresh key derived from rng2
# rather than the already-consumed rng0.
rng_post = jax.random.fold_in(rng2, 0)
guide.sample_posterior(rng_post, svi.get_params(svi_state))

In [None]:
# Draw the latent `pi` and `gamma` sites from the fitted guide. This is a
# single random draw from the variational posterior, not a MAP estimate;
# `guide.median(...)` gives a deterministic point estimate instead.
rng_pred = jax.random.fold_in(rng2, 1)
svi_point = npyro.infer.Predictive(
    prior,
    guide=guide,
    params=svi.get_params(svi_state),
    num_samples=1,
    return_sites=['pi', 'gamma'],
)
mapest = dict(svi_point(rng_pred))

In [None]:
# Ground truth for comparison with the fitted gamma.
ax = sns.heatmap(sim['gamma'].squeeze())
ax.set(title="Simulated gamma (truth)", xlabel="position", ylabel="strain");

In [None]:
# The model's priors are symmetric under relabeling strains (iid gamma rows,
# symmetric Dirichlet on rho). The fitted rows may therefore be a permutation
# of the simulated rows.
ax = sns.heatmap(mapest['gamma'].squeeze())
ax.set(title="Fitted gamma (one posterior draw)", xlabel="position", ylabel="strain");

In [None]:
# NOTE(review): legacy Pyro (PyTorch) fit from an earlier version of this
# notebook. It needs `pyro`, `model_fit` and `y_obs`, none of which are defined
# in the cells shown here, so it fails on a fresh kernel as-is.
# TODO: port to numpyro or delete.
guide = pyro.infer.autoguide.AutoNormal(model_fit)

opt = pyro.optim.Adamax({"lr": 1e-1}, {"clip_norm": 100.})

# Optimize with the AutoNormal guide constructed above.
svi = pyro.infer.SVI(
    model_fit,
    guide,
    opt,
    loss=pyro.infer.JitTrace_ELBO()
)

pyro.clear_param_store()

pbar = tqdm(range(10000))
# Loss returned by `svi.step` at each iteration (the negative ELBO).
history = []
for _ in pbar:
    loss = svi.step(
        y=y_obs,
    )

    if np.isnan(loss):
        break
    history.append(loss)

    if len(history) > 1:
        pbar.set_postfix({'loss': history[-1], 'delta': history[-2] - history[-1]})

pbar.refresh()