In [None]:
import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

## First attempt

In [None]:
def model(
    s,
    g=None,
    n=None,
    a=None,
    m=None,
    y=None,
    gamma_hyper=torch.tensor(0.),
    rho_hyper=torch.tensor(0.),
    pi_hyper=torch.tensor(0.),
    epsilon_hyper=torch.tensor(100.),
    alpha_hyper=torch.tensor(3.)
):
    """First-attempt Pyro model with a Dirichlet-multinomial likelihood.

    Args:
        s: Size of the 'strain' plate.
        g: Size of the 'position' plate; overwritten from ``y`` when ``y`` is given.
        n: Size of the 'sample' plate; overwritten from ``y`` when ``y`` is given.
        a: Length of the Dirichlet concentration used for ``gamma`` (the
            commented-out plate below calls this axis 'allele'); overwritten
            from ``y`` when ``y`` is given.
        m: ``total_count`` of the likelihood; overwritten with ``y.sum(-1)``
            when ``y`` is given.
        y: Observed counts passed as ``obs`` to the 'y' site, or None to
            sample them. Its last three axes are read as ``(n, g, a)``.
        gamma_hyper: Log-scale of the symmetric Dirichlet concentration of 'gamma'.
        rho_hyper: Log-scale of the symmetric Dirichlet concentration of 'rho'.
        pi_hyper: Log-scale factor applied to ``rho`` to form the
            concentration of 'pi'.
        epsilon_hyper: Second parameter of the Beta(1, .) prior on 'epsilon'.
        alpha_hyper: Scale of the Normal(5, .) prior on 'alpha'.

    Returns:
        Tuple ``(gamma, rho, pi, epsilon, alpha, y)``; ``rho`` is returned
        after its ``unsqueeze(-1)``.

    Raises:
        AssertionError: If ``y`` is given and any of ``m``, ``a``, ``g``,
            ``n`` was also given but disagrees with ``y``.
    """

    if y is not None:
        # Cross-check any explicitly supplied sizes against the observation.
        if m is not None:
            assert torch.all(y.sum(-1) == m)
        if a is not None:
            assert y.shape[-1] == a
        if g is not None:
            assert y.shape[-2] == g
        if n is not None:
            assert y.shape[-3] == n
        # Set them using y anyway.
        m = y.sum(-1)
        n, g, a = y.shape[-3:]

#     allele_plate = pyro.plate('allele', a)
    strain_plate = pyro.plate('strain', s, dim=-1)
    position_plate = pyro.plate('position', g, dim=-2)
    sample_plate = pyro.plate('sample', n, dim=-3)

    # One length-`a` Dirichlet draw per (position, strain) pair.
    with position_plate:
        with strain_plate:
            gamma = pyro.sample('gamma', dist.Dirichlet(torch.ones(a) * torch.exp(gamma_hyper)))
#    assert gamma.shape == (g, s, a), gamma.shape

    # NOTE(review): the unsqueeze gives `rho` a trailing axis of length 1, so
    # the Dirichlet for 'pi' below is parameterised by a tensor whose last
    # axis has a single entry. That looks like it makes every 'pi' draw
    # degenerate -- confirm whether this was intended.
    rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s) * torch.exp(rho_hyper))).unsqueeze(-1)
    with sample_plate:
        pi = pyro.sample('pi', dist.Dirichlet(rho * torch.exp(pi_hyper)))
        # Weight `gamma` by `pi` and sum out the second-to-last axis.
        p_noerr = pyro.deterministic('p_noerr', (pi * gamma).sum(-2))
        epsilon = pyro.sample('epsilon', dist.Beta(1., epsilon_hyper))
        # Blend p_noerr with its complement using weight epsilon / a.
        p = pyro.deterministic('p', (1 - epsilon / a) * (p_noerr) + (epsilon / a) * (1 - p_noerr))
        alpha = pyro.sample('alpha', dist.Normal(5., alpha_hyper).expand([1]))
    # exp(alpha) scales `p` into the Dirichlet-multinomial concentration.
    y = pyro.sample('y', dist.DirichletMultinomial(p * torch.exp(alpha), total_count=m), obs=y)
    return gamma, rho, pi, epsilon, alpha, y


# Pin the global RNG state so that a fresh "Restart & Run All" reproduces the
# same simulated data set (and therefore the same downstream fits).
pyro.set_rng_seed(0)

# Simulation settings: every size is given explicitly because no observation
# is supplied for the model to read them from.
model_sim = partial(model,
    s=3,
    n=5,
    g=10,
    a=2,
    m=10,
    gamma_hyper=torch.tensor(-5.),
    rho_hyper=torch.tensor(-1.),
    pi_hyper=torch.tensor(10.),
    epsilon_hyper=torch.tensor(100.),
    alpha_hyper=torch.tensor(3.),
)
# One forward draw; `sim_y` is the simulated observation used by later cells.
sim_gamma, sim_rho, sim_pi, sim_epsilon, sim_alpha, sim_y = model_sim()

In [None]:
# Model used for inference: only the strain count and the hyperparameters are
# fixed; the remaining sizes are read from the observation passed as `y`.
model_fit = partial(
    model,
    s=3,
    gamma_hyper=torch.tensor(0.),
    rho_hyper=torch.tensor(0.),
    pi_hyper=torch.tensor(0.),
    epsilon_hyper=torch.tensor(100.),
    alpha_hyper=torch.tensor(3.),
)

nuts_kernel = pyro.infer.NUTS(model_fit, jit_compile=True, ignore_jit_warnings=True)
mcmc = pyro.infer.MCMC(
    nuts_kernel,
    num_samples=5,
    warmup_steps=5,
    num_chains=1,
)
# Condition on the simulated counts. `sim_y` is the only observation defined
# in this notebook; a bare `y` does not exist on a fresh kernel.
mcmc.run(y=sim_y)

posterior_samples = mcmc.get_samples()
# Predictive objects are called directly (this dispatches to `forward`).
posterior_predictive = pyro.infer.Predictive(model_fit, posterior_samples)(
    y=sim_y
)
# NOTE(review): `y` is passed as an observation here too, so the 'y' entry of
# `prior` is presumably just the observed data rather than a prior-predictive
# draw -- confirm this is what should be handed to ArviZ.
prior = pyro.infer.Predictive(model_fit, num_samples=500)(
    y=sim_y
)

pyro_data = az.from_pyro(
    mcmc,
    prior=prior,
    posterior_predictive=posterior_predictive,
)

In [None]:
# Alternative guides that were tried:
#guide = pyro.infer.autoguide.AutoDiagonalNormal(model)
#guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model_fit, rank=10)
#guide = pyro.infer.autoguide.AutoLaplaceApproximation(model, )
#guide = pyro.infer.autoguide.AutoIAFNormal(model, hidden_dim=[3], num_transforms=2)
guide = pyro.infer.autoguide.AutoDelta(model_fit)

#opt = pyro.optim.Adam({"lr": 0.0001})
opt = pyro.optim.RMSprop({"lr": 0.0005})

svi = pyro.infer.SVI(model_fit, guide, opt, loss=pyro.infer.Trace_ELBO())

# Start from an empty parameter store so re-running the cell restarts the fit.
pyro.clear_param_store()
for i in range(10000):
    # Fit against the simulated counts; `sim_y` is the observation produced by
    # the simulation cell (a bare `y` is undefined on a fresh kernel).
    elbo = svi.step(y=sim_y)
    if i % 1000 == 0:
        print(f"Elbo loss ({i}): {elbo:0.2e}")

In [None]:
# Prediction model: sizes match the simulation, and the total counts are taken
# from the simulated observation `sim_y` (the names `y_sim` and `y` are never
# defined in this notebook).
model_predict = partial(model,
    s=3,
    n=5,
    g=10,
    a=2,
    m=sim_y.sum(-1),
    gamma_hyper=torch.tensor(-5.),
    rho_hyper=torch.tensor(-1.),
    pi_hyper=torch.tensor(10.),
    epsilon_hyper=torch.tensor(100.),
    alpha_hyper=torch.tensor(3.),
)

svi_predictive = pyro.infer.Predictive(model_predict, guide=guide, num_samples=500, parallel=True)
# Convert every returned site to a NumPy array for inspection and plotting.
svi_posterior = {k: v.detach().numpy()
                 for k, v
                 in svi_predictive(y=sim_y).items()}
# svi_prior = pyro.infer.Predictive(model_predict, num_samples=500)(
#     y=sim_y
# )
# posterior_predictive = svi_predictive()['y']

In [None]:
# Show the 'gamma' values returned by the SVI predictive run.
svi_posterior["gamma"]

## Take #2

In [None]:
def model(
    s,
    g=None,
    n=None,
    a=None,
    m=None,
    y=None,
    gamma_hyper=torch.tensor(0.),
    rho_hyper=torch.tensor(0.),
    pi_hyper=torch.tensor(0.),
    epsilon_hyper=torch.tensor(100.),
    alpha_hyper=torch.tensor(3.)
):
    """Second-attempt Pyro model with a Dirichlet-multinomial likelihood.

    Compared with a layout that gives every plate its own batch axis, the
    'sample' plate here sits on ``dim=-1`` and the per-sample sites are aligned
    with ``gamma`` by explicit ``unsqueeze`` calls.

    Args:
        s: Size of the 'strain' plate.
        g: Size of the 'position' plate; overwritten from ``y`` when ``y`` is given.
        n: Size of the 'sample' plate; overwritten from ``y`` when ``y`` is given.
        a: Length of the Dirichlet concentration used for ``gamma``;
            overwritten from ``y`` when ``y`` is given.
        m: ``total_count`` of the likelihood; overwritten with ``y.sum(-1)``
            when ``y`` is given.
        y: Observed counts passed as ``obs`` to the 'y' site, or None to
            sample them. Its last three axes are read as ``(n, g, a)``.
        gamma_hyper: Log-scale of the symmetric Dirichlet concentration of 'gamma'.
        rho_hyper: Log-scale of the symmetric Dirichlet concentration of 'rho'.
        pi_hyper: Log-scale factor applied to ``rho`` to form the
            concentration of 'pi'.
        epsilon_hyper: Second parameter of the Beta(1, .) prior on 'epsilon'.
        alpha_hyper: Scale of the Normal(5, .) prior on 'alpha'.

    Returns:
        Tuple ``(gamma, rho, pi, p_noerr, epsilon, alpha, y)``.

    Raises:
        AssertionError: If ``y`` is given and any of ``m``, ``a``, ``g``,
            ``n`` was also given but disagrees with ``y``.
    """

    if y is not None:
        # Cross-check any explicitly supplied sizes against the observation.
        if m is not None:
            assert torch.all(y.sum(-1) == m)
        if a is not None:
            assert y.shape[-1] == a
        if g is not None:
            assert y.shape[-2] == g
        if n is not None:
            assert y.shape[-3] == n
        # Set them using y anyway.
        m = y.sum(-1)
        n, g, a = y.shape[-3:]

    # Earlier plate layout, kept for reference:
#    allele_plate = pyro.plate('allele', a, dim=-1)
#    strain_plate = 
#    position_plate = pyro.plate('position', g, dim=-2)
#    sample_plate = pyro.plate('sample', n, dim=-3)

    # One length-`a` Dirichlet draw per (position, strain) pair.
    with pyro.plate('position', g, dim=-2), pyro.plate('strain', s, dim=-1):
        gamma = pyro.sample('gamma', dist.Dirichlet(torch.ones(a) * torch.exp(gamma_hyper)))

    rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s) * torch.exp(rho_hyper)))
    # The 'sample' plate reuses dim=-1; it is not nested with the 'strain'
    # plate above, which uses the same dim.
    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(rho * torch.exp(pi_hyper)))
        # The two unsqueezes insert singleton axes so `pi` broadcasts against
        # `gamma`; the second-to-last axis is then summed out.
        p_noerr = pyro.deterministic('p_noerr', (pi.unsqueeze(-1).unsqueeze(-3) * gamma).sum(-2))
        epsilon = pyro.sample('epsilon', dist.Beta(1., epsilon_hyper))
        # Blend p_noerr with its complement using weight epsilon / a; epsilon
        # gets two trailing singleton axes to broadcast against p_noerr.
        p = pyro.deterministic(
            'p',
            (1 - epsilon.unsqueeze(-1).unsqueeze(-1) / a) * (p_noerr) +
            (epsilon.unsqueeze(-1).unsqueeze(-1) / a) * (1 - p_noerr)
        )
        alpha = pyro.sample('alpha', dist.Normal(5., alpha_hyper))

    # exp(alpha), broadcast the same way as epsilon, scales `p` into the
    # Dirichlet-multinomial concentration.
    y = pyro.sample('y', dist.DirichletMultinomial(p * torch.exp(alpha.unsqueeze(-1).unsqueeze(-1)), total_count=m), obs=y)
    return gamma, rho, pi, p_noerr, epsilon, alpha, y


# Pin the global RNG state so that a fresh "Restart & Run All" reproduces the
# same simulated data set (and therefore the same downstream fits).
pyro.set_rng_seed(0)

# Simulation settings: every size is given explicitly because no observation
# is supplied for the model to read them from.
model_sim = partial(model,
    s=3,
    n=5,
    g=10,
    a=2,
    m=10,
    gamma_hyper=torch.tensor(-5.),
    rho_hyper=torch.tensor(-1.),
    pi_hyper=torch.tensor(10.),
    epsilon_hyper=torch.tensor(100.),
    alpha_hyper=torch.tensor(3.),
)
# One forward draw; `sim_y` is the simulated observation used by later cells.
sim_gamma, sim_rho, sim_pi, sim_p_noerr, sim_epsilon, sim_alpha, sim_y = model_sim()
# Display the shape of every returned tensor as a quick sanity check.
sim_gamma.shape, sim_rho.shape, sim_pi.shape, sim_p_noerr.shape, sim_epsilon.shape, sim_alpha.shape, sim_y.shape

In [None]:
# Inference model: the strain count and hyperparameters are pinned here, while
# the remaining sizes come from whatever observation is passed as `y`.
model_fit = partial(
    model, s=3,
    gamma_hyper=torch.tensor(0.0),
    rho_hyper=torch.tensor(0.0),
    pi_hyper=torch.tensor(0.0),
    epsilon_hyper=torch.tensor(100.0),
    alpha_hyper=torch.tensor(3.0),
)

# Drop any parameters left over from earlier fits before sampling.
pyro.clear_param_store()
nuts_kernel = pyro.infer.NUTS(
    model_fit,
    jit_compile=True,
    ignore_jit_warnings=True,
)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=50, warmup_steps=20, num_chains=1)
mcmc.run(y=sim_y)

# Posterior draws, then predictive draws from the posterior and from the prior.
posterior_samples = mcmc.get_samples()
posterior_predictive = pyro.infer.Predictive(model_fit, posterior_samples).get_samples(y=sim_y)
prior = pyro.infer.Predictive(model_fit, num_samples=500).get_samples(y=sim_y)

# Bundle everything into an ArviZ InferenceData object.
pyro_data = az.from_pyro(mcmc, prior=prior, posterior_predictive=posterior_predictive)

In [None]:
# Point-estimate guide. Other autoguides that were tried: AutoDiagonalNormal,
# AutoLowRankMultivariateNormal (rank=10), AutoLaplaceApproximation and
# AutoIAFNormal (hidden_dim=[3], num_transforms=2).
guide = pyro.infer.autoguide.AutoDelta(model_fit)

# RMSprop is the optimiser in use; Adam with lr=0.0001 was the alternative.
opt = pyro.optim.RMSprop({"lr": 5e-4})

svi = pyro.infer.SVI(
    model_fit,
    guide,
    opt,
    loss=pyro.infer.Trace_ELBO(),
)

# Reset stored parameters so the optimisation starts from scratch.
pyro.clear_param_store()
for i in range(10_000):
    elbo = svi.step(y=sim_y)
    # Report the loss every 1000 steps.
    if not i % 1000:
        print(f"Elbo loss ({i}): {elbo:0.2e}")

In [None]:
# Prediction model: same sizes as the simulation, with total counts taken from
# the simulated observation.
model_predict = partial(
    model,
    s=3, n=5, g=10, a=2,
    m=sim_y.sum(-1),
    gamma_hyper=torch.tensor(-5.0),
    rho_hyper=torch.tensor(-1.0),
    pi_hyper=torch.tensor(10.0),
    epsilon_hyper=torch.tensor(100.0),
    alpha_hyper=torch.tensor(3.0),
)

# Predictive driven by the fitted guide (vectorised over the 500 draws).
svi_predictive = pyro.infer.Predictive(
    model_predict,
    guide=guide,
    num_samples=500,
    parallel=True,
)
# Prior draws from the same model, without a guide and without observations.
svi_prior = pyro.infer.Predictive(model_predict, num_samples=500).get_samples()
# Posterior-predictive counts: only the 'y' site of an unconditioned run.
posterior_predictive = svi_predictive()['y']

## Third time's the charm

In [None]:
def model(
    s,
    m,
    y=None,
    gamma_hyper=torch.tensor(0.),
    pi_hyper=torch.tensor(0.),
    rho_hyper=torch.tensor(0.),
    epsilon_hyper=torch.tensor(100.),
):
    """Third-iteration Pyro model with a Binomial likelihood.

    Args:
        s: Size of the 'strain' plate.
        m: Tensor of Binomial total counts; its shape is unpacked as
            ``(n, g)`` and fixes the 'sample' and 'position' plate sizes.
        y: Observed counts passed as ``obs`` to the 'y' site, or None to
            sample them.
        gamma_hyper: 'gamma' is Beta(c, c) with ``c = exp(-gamma_hyper)``.
        pi_hyper: 'pi' is Dirichlet with concentration ``rho * exp(-pi_hyper)``.
        rho_hyper: 'rho' is a symmetric Dirichlet with concentration
            ``exp(-rho_hyper)``.
        epsilon_hyper: Second parameter of the Beta(1, .) prior on 'epsilon'.

    Returns:
        The value of the 'y' sample site.
    """
    n, g = m.shape

    # gamma: one Beta draw per (strain, position); trailing shape (s, g).
    gamma_conc = torch.exp(-gamma_hyper)
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample('gamma', dist.Beta(gamma_conc, gamma_conc))

    # rho: a single length-s Dirichlet draw shared by all samples.
    rho_conc = torch.ones(s) * torch.exp(-rho_hyper)
    rho = pyro.sample('rho', dist.Dirichlet(rho_conc))

    # pi (trailing shape (n, s)) and epsilon: one draw per sample.
    pi_scale = torch.exp(-pi_hyper)
    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(rho * pi_scale))
        epsilon = pyro.sample('epsilon', dist.Beta(1., epsilon_hyper))
    # Trailing singleton axis so epsilon broadcasts across positions.
    epsilon = epsilon.unsqueeze(-1)

    # Mix gamma across strains for every sample; trailing shape (n, g).
    p_noerr = pyro.deterministic('p_noerr', torch.matmul(pi, gamma))
    # With weight epsilon the probability is replaced by its complement.
    p = pyro.deterministic('p', (1 - epsilon) * p_noerr + epsilon * (1 - p_noerr))

    # TODO: Add overdispersion?
    return pyro.sample('y', dist.Binomial(total_count=m, probs=p), obs=y)

# Pin the global RNG state so that a fresh "Restart & Run All" reproduces the
# same simulated data set (and therefore the same downstream fit).
pyro.set_rng_seed(0)

# Simulated design: 5 samples x 100 positions, each with the same total count.
n, g = 5, 100
depth = 10
m = torch.ones((n, g)) * depth

model_sim = partial(model,
    s=3,
    m=m,
    gamma_hyper=torch.tensor(5.),
    pi_hyper=torch.tensor(1.),
    rho_hyper=torch.tensor(1.),
    epsilon_hyper=torch.tensor(100.),
#    alpha_hyper=torch.tensor(3.),
)

# Run the model once under a trace and print the shape of every site as a
# sanity check on the plate layout.
trace = pyro.poutine.trace(model_sim).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

# Draw the data set used for fitting; `sim` keeps every site so the true
# latent values can be compared with the fit later on.
sim = pyro.infer.Predictive(model_sim, num_samples=1)()
sim_y = sim['y'].squeeze()

In [None]:
# Inference model: fitted with 5 strains and the same total-count tensor `m`
# as the simulation.
model_fit = partial(
    model, s=5, m=m,
    gamma_hyper=torch.tensor(5.0),
    pi_hyper=torch.tensor(0.0),
    rho_hyper=torch.tensor(1.0),
    epsilon_hyper=torch.tensor(100.0),
)

In [None]:
# Laplace autoguide is the variant in use. Others that were tried:
# AutoDiagonalNormal, AutoLowRankMultivariateNormal (rank=10),
# AutoIAFNormal (hidden_dim=[4000], num_transforms=1) and AutoDelta.
guide = pyro.infer.autoguide.AutoLaplaceApproximation(model_fit)

# Adam is the optimiser in use; RMSprop with lr=0.001 was the alternative.
opt = pyro.optim.Adam({"lr": 1e-2})

svi = pyro.infer.SVI(
    model_fit,
    guide,
    opt,
    loss=pyro.infer.Trace_ELBO(),
)

# Reset stored parameters so the optimisation starts from scratch.
pyro.clear_param_store()

pbar = tqdm(range(10_000))
history = []
for i in pbar:
    elbo = svi.step(y=sim_y)
    history.append(elbo)
    # Refresh the progress-bar readout every 100 steps.
    if not i % 100:
        pbar.set_postfix({'ELBO': elbo})

In [None]:
# Loss trace of the SVI run: one value per optimisation step.
plt.plot(history)
plt.xlabel('SVI step')
plt.ylabel('ELBO loss');

In [None]:
# Predictive pushes the guide's samples through this model, so the strain
# count has to match the model the guide was built from. It is read from
# `model_fit` instead of being hard-coded, which keeps the two in sync.
model_predict = partial(model,
    s=model_fit.keywords['s'],
    m=m,
    gamma_hyper=torch.tensor(-2.),
    pi_hyper=torch.tensor(0.),
    epsilon_hyper=torch.tensor(100.),
)
# NOTE(review): the hyperparameters above differ from those in `model_fit`
# (and rho_hyper falls back to its default). They only influence `svi_prior`,
# since every latent site is supplied by the guide -- confirm they are intended.

svi_predictive = pyro.infer.Predictive(model_predict, guide=guide, num_samples=1)
# Guide draw with `y` fixed to the simulated counts, converted to NumPy.
svi_posterior = {k: v.detach().numpy()
                 for k, v
                 in svi_predictive(y=sim_y).items()}
# Prior draws from the prediction model (no guide, no observation).
svi_prior = pyro.infer.Predictive(model_predict, num_samples=500)()
# Posterior-predictive counts: 'y' is sampled because no observation is given.
posterior_predictive = svi_predictive()['y']

In [None]:
# Simulated 'pi': one row per sample, one column per strain.
sns.heatmap(sim['pi'].squeeze().numpy(), vmin=0, vmax=1).set(
    title='Simulated pi', xlabel='strain', ylabel='sample',
);

In [None]:
# Fitted 'pi': averaging over the two leading axes leaves one row per sample
# and one column per fitted strain.
sns.heatmap(svi_posterior['pi'].mean(0).mean(0), vmin=0, vmax=1).set(
    title='Inferred pi', xlabel='strain', ylabel='sample',
);

In [None]:
# Simulated 'gamma', transposed: one row per position, one column per strain.
sns.heatmap(sim['gamma'].squeeze().numpy().T).set(
    title='Simulated gamma', xlabel='strain', ylabel='position',
);

In [None]:
# Fitted 'gamma', averaged over the leading axis and transposed: one row per
# position, one column per fitted strain.
sns.heatmap(svi_posterior['gamma'].mean(0).T).set(
    title='Inferred gamma', xlabel='strain', ylabel='position',
);

In [None]:
# Simulated versus fitted 'epsilon' values, one point per entry.
plt.scatter(sim['epsilon'].squeeze(), svi_posterior['epsilon'].squeeze())
plt.xlabel('simulated epsilon')
plt.ylabel('inferred epsilon');