In [None]:
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import pyro.infer
from lib.util import info

import pyro.primitives

In [None]:
import numpy as np

In [None]:
# Removed scratch call: `torch.sum()` with no input tensor raises TypeError,
# which made this cell fail under Restart & Run All.

In [None]:
import scipy.stats
import matplotlib.pyplot as plt

def power_perturb_rvs(alpha, random_state=None):
    """Draw one power-perturbed point on the simplex.

    A flat Dirichlet(1, ..., 1) sample ``x`` is raised elementwise to the power
    ``1 / alpha`` and renormalized to sum to one.

    The power and the normalization are computed in log space. With small
    ``alpha`` (e.g. 5e-4, a power of 2000), the direct ``x ** (1 / alpha)``
    underflows to all zeros and the normalized ratio becomes 0 / 0 = NaN.

    Args:
        alpha: 1-D array of per-component power parameters. The simplex
            dimension equals ``len(alpha)``.
        random_state: Optional seed or RNG forwarded to ``rvs``. The default
            ``None`` keeps the original unseeded behavior.

    Returns:
        Array of shape (1, len(alpha)) whose single row sums to 1.
    """
    from scipy.special import logsumexp

    x = scipy.stats.dirichlet(np.ones_like(alpha)).rvs(random_state=random_state)
    log_x_pow = np.log(x) / alpha
    # Subtracting the log-normalizer before exponentiating keeps values in range.
    return np.exp(log_x_pow - logsumexp(log_x_pow, axis=-1, keepdims=True))

frac = np.array([0.95, 0.05])
conc = 1e-2
# Stack 10000 independent (1, 2) draws into a (10000, 2) array, then show the
# per-component mean.
x = np.vstack([power_perturb_rvs(frac * conc) for _draw in range(10000)])
x.mean(axis=0)

In [None]:
import seaborn as sns

In [None]:
def power_perturb_model(alpha, name='x'):
    """Sample a flat Dirichlet site and apply a power perturbation.

    A Dirichlet sample site ``name`` is drawn with all-ones concentration
    shaped like ``alpha``. Its last axis is the simplex dimension. Each
    component is raised to ``1 / alpha``, and the result is renormalized over
    the last axis. Both steps are carried out in log space.

    Returns:
        Tensor shaped like ``alpha`` whose last axis sums to one.
    """
    concentration = torch.ones_like(alpha)
    draw = pyro.sample(name, dist.Dirichlet(concentration))
    scaled = torch.log(draw) / alpha
    normalizer = torch.logsumexp(scaled, -1, keepdim=True)
    return (scaled - normalizer).exp()

# Draw a (100, 200) batch with every power parameter set to 0.1 and
# visualize it.
x = power_perturb_model(1e-1 * torch.ones((100, 200)))
sns.heatmap(x)

In [None]:
# Scratch check of float64 underflow: exp(-700) is about 9.9e-305, still above
# the smallest normal float64 (~2.2e-308). Exponents below roughly -745 give 0.
# NOTE(review): presumably probing how negative the log-space powers can get
# before exp() underflows — confirm.
np.exp(-700)

In [None]:
def _power_perturb(raw, power):
    """Raise simplex vector(s) ``raw`` to ``1 / power`` and renormalize over the last axis.

    Computed in log space: divide log(raw) by ``power``, subtract the
    logsumexp, then exponentiate. This keeps small ``power`` values from
    underflowing.
    """
    log_raw_pow = torch.log(raw) / power
    return torch.exp(log_raw_pow - torch.logsumexp(log_raw_pow, -1, keepdim=True))


def model(
    s,
    m,
    y=None,
    gamma0=1.0,
    pi0=1.0,
    rho0=1.0,
    epsilon0=0.01,
    alpha0=100.0,
    dtype=torch.float32,
    device="cpu",
):
    """Generative Pyro model for counts ``y`` out of ``m`` trials per cell.

    The per-cell success probability is ``pi @ gamma``. Here ``pi`` holds
    per-row weights over ``s`` components and ``gamma`` holds per-component,
    per-column Beta probabilities. This probability is mixed with a symmetric
    error term ``epsilon`` and then drives a BetaBinomial likelihood with
    concentration ``alpha``.

    Args:
        s: Size of the "strain" plate, i.e. the number of components mixed
            by ``pi``.
        m: Total counts with shape (n, g). ``n`` sizes the "sample" plate,
            ``g`` sizes the "position" plate, and ``m`` is the BetaBinomial
            ``total_count``.
        y: Observed counts, expected to broadcast against ``m``. If None,
            ``y`` is sampled.
        gamma0: Shape of the Gamma(gamma0, 1) prior on ``gamma_hyper``.
        pi0: Shape of the Gamma(pi0, 1) prior on ``pi_hyper``.
        rho0: Shape of the Gamma(rho0, 1) prior on ``rho_hyper``.
        epsilon0: ``epsilon_hyper`` ~ Beta(1, 1 / epsilon0).
        alpha0: Shape of the Gamma(alpha0, 1) prior on ``alpha_hyper``.
        dtype: Tensor dtype for all inputs, including ``y``.
        device: Tensor device for all inputs, including ``y``.

    Returns:
        The value of the ``y`` sample site, shape (n, g).
    """
    # Cast inputs and set device. as_tensor avoids the copy-construct warning
    # that torch.tensor emits when a value is already a tensor.
    m, gamma0, pi0, rho0, epsilon0, alpha0 = [
        torch.as_tensor(v, dtype=dtype, device=device)
        for v in [m, gamma0, pi0, rho0, epsilon0, alpha0]
    ]
    if y is not None:
        # Match the other inputs. A bare torch.tensor(y) keeps y's own dtype
        # and always lands on the CPU.
        y = torch.as_tensor(y, dtype=dtype, device=device)

    n, g = m.shape

    gamma_hyper = pyro.sample("gamma_hyper", dist.Gamma(gamma0, 1.0))

    with pyro.plate("position", g, dim=-1):
        with pyro.plate("strain", s, dim=-2):
            gamma = pyro.sample("gamma", dist.Beta(gamma_hyper, gamma_hyper))
    # gamma.shape == (s, g)

    # rho: power-perturbed flat Dirichlet over the s components.
    rho_hyper = pyro.sample("rho_hyper", dist.Gamma(rho0, 1.0))
    rho_raw = pyro.sample(
        "rho_raw",
        dist.Dirichlet(torch.ones(s, dtype=dtype, device=device)),
    )
    rho = pyro.deterministic("rho", _power_perturb(rho_raw, rho_hyper))

    epsilon_hyper = pyro.sample("epsilon_hyper", dist.Beta(1.0, 1 / epsilon0))
    alpha_hyper = pyro.sample("alpha_hyper", dist.Gamma(alpha0, 1.0))
    pi_hyper = pyro.sample("pi_hyper", dist.Gamma(pi0, 1.0))

    with pyro.plate("sample", n, dim=-1):
        # Construct pi from PowerPert distribution
        # TODO: Add back rho influence
        # NOTE(review): rho already enters through the Dirichlet concentration
        # (rho * s) below; this TODO may be stale — confirm.
        pi_raw = pyro.sample("pi_raw", dist.Dirichlet(rho * s))
        pi = pyro.deterministic("pi", _power_perturb(pi_raw, pi_hyper))

        alpha = pyro.sample("alpha", dist.Gamma(alpha_hyper, 1.0)).unsqueeze(
            -1
        )
        epsilon = pyro.sample(
            "epsilon", dist.Beta(1.0, 1 / epsilon_hyper)
        ).unsqueeze(-1)
    # pi.shape == (n, s)
    # alpha.shape == epsilon.shape == (n, 1): the sites are (n,), and the
    # trailing axis lets them broadcast across positions.

    p_noerr = pyro.deterministic("p_noerr", pi @ gamma)
    # Symmetric error: with probability epsilon / 2 the observation flips.
    p = pyro.deterministic(
        "p", (1 - epsilon / 2) * (p_noerr) + (epsilon / 2) * (1 - p_noerr)
    )
    # p.shape == (n, g)

    y = pyro.sample(
        "y",
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
        obs=y,
    )
    # y.shape == (n, g)
    return y

def conditioned_model(
    model, data=None, dtype=torch.float32, device="cpu", **kwargs,
):
    """Condition ``model`` on ``data`` and bind its remaining keyword arguments.

    Args:
        model: Pyro model callable that accepts ``dtype`` and ``device``
            keywords.
        data: Mapping of sample-site name to value. Each value is converted to
            a tensor of ``dtype`` on ``device``. ``None`` means no conditioning.
            ``None`` replaces the previous mutable ``{}`` default; the behavior
            is the same.
        dtype: Used for the tensor conversion and also forwarded to ``model``.
        device: Used for the tensor conversion and also forwarded to ``model``.
        **kwargs: Further keyword arguments bound to ``model``.

    Returns:
        A ``functools.partial`` wrapping ``pyro.condition(model, data=...)``.
    """
    if data is None:
        data = {}
    # torch.tensor copies, so later edits to the caller's arrays cannot leak in.
    data = {
        k: torch.tensor(v, dtype=dtype, device=device) for k, v in data.items()
    }
    return partial(
        pyro.condition(model, data=data), dtype=dtype, device=device, **kwargs,
    )


def _elbo_progress(history, lag):
    """Summarize the loss history for the progress bar and the convergence test.

    Returns ``(postfix, delta_lag)``:

    - ``postfix`` always holds the latest ``ELBO``. It gains ``delta_1`` once
      two values exist and ``delta_{lag}`` once ``lag + 1`` values exist.
    - ``delta_lag`` is the mean per-step decrease over the last ``lag`` steps,
      or ``None`` while the history is too short.
    """
    postfix = {"ELBO": history[-1]}
    if len(history) >= 2:
        postfix["delta_1"] = history[-2] - history[-1]
    if len(history) <= lag:
        return postfix, None
    # history[-(lag + 1)] -> history[-1] spans exactly `lag` steps.
    delta_lag = (history[-(lag + 1)] - history[-1]) / lag
    postfix[f"delta_{lag}"] = delta_lag
    return postfix, delta_lag


def find_map(
    model,
    lag=10,
    stop_at=1.0,
    max_iter=int(1e5),
    learning_rate=1e-0,
    clip_norm=100.0,
    auto_guide=pyro.infer.autoguide.AutoLaplaceApproximation,
    num_samples=1,
):
    """Fit ``model`` with SVI and an autoguide, then draw from the fitted guide.

    Args:
        model: Pyro model callable. It is called with no arguments by
            ``svi.step()`` and by ``Predictive(...)()``.
        lag: Window length, in steps, for the convergence criterion. Must be
            at least 1.
        stop_at: Stop once the mean per-step loss decrease over the last
            ``lag`` steps falls below this value.
        max_iter: Maximum number of SVI steps.
        learning_rate: Adamax learning rate.
        clip_norm: Gradient-norm clip passed to ``pyro.optim.Adamax``.
        auto_guide: Callable that builds a guide from ``model``.
        num_samples: Number of guide draws gathered via ``Predictive``.

    Returns:
        ``(mapest, history)``:

        - ``mapest`` maps site name to a squeezed numpy array. When
          ``num_samples > 1``, draws remain on axis 0.
        - ``history`` is the array of per-step losses.

    Raises:
        ValueError: If ``lag`` < 1.
        AssertionError: If optimization stops without meeting ``stop_at``
            (iteration limit, NaN loss, or KeyboardInterrupt).
    """
    if lag < 1:
        raise ValueError(f"lag must be >= 1, got {lag}")

    guide = auto_guide(model)
    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adamax(
            optim_args={"lr": learning_rate},
            clip_args={"clip_norm": clip_norm},
        ),
        loss=pyro.infer.JitTrace_ELBO(),
    )

    pyro.clear_param_store()
    pbar = tqdm(range(max_iter), position=0, leave=True)
    history = []
    # Initialized so the final check never hits an unbound name, even when
    # the loop exits before `lag` steps have been recorded.
    delta_lag = float("inf")
    converged = False
    stop_reason = f"Reached {max_iter} iterations"
    try:
        for _ in pbar:
            elbo = svi.step()

            if np.isnan(elbo):
                stop_reason = f"Loss became NaN after {len(history)} iterations"
                info("Optimization produced a NaN loss")
                break

            # Fit tracking
            history.append(elbo)

            # Reporting/Breaking
            postfix, latest_delta = _elbo_progress(history, lag)
            pbar.set_postfix(postfix)
            if latest_delta is None:
                continue
            delta_lag = latest_delta
            if delta_lag < stop_at:
                converged = True
                info("Optimization converged")
                break
    except KeyboardInterrupt:
        stop_reason = f"Interrupted after {len(history)} iterations"
        info("Optimization interrupted")
    pbar.refresh()
    if not converged:
        raise AssertionError(
            f"{stop_reason} with a per-step improvement of {delta_lag}. "
            "Consider setting max_iter or stop_at larger; increasing "
            "learning_rate may also help, although it could also lead to "
            "numerical issues."
        )
    # Gather MAP from parameter-store
    mapest = {
        k: v.detach().cpu().numpy().squeeze()
        for k, v in pyro.infer.Predictive(
            model, guide=guide, num_samples=num_samples,
        )().items()
    }
    return mapest, np.array(history)

In [None]:
n_sim, g_sim = 300, 200
m_sim = 10 * np.ones((n_sim, g_sim))
s_sim = 100

# Seed torch/numpy/random so the simulated dataset, and everything fit to it
# downstream, is reproducible under Restart & Run All.
pyro.set_rng_seed(0)

model_sim = conditioned_model(
    model,
    data=dict(
        gamma_hyper=0.01,
        pi_hyper=1e-2,
        rho_hyper=1e-0,
        epsilon_hyper=0.01,
        alpha_hyper=1000,
    ),
    s=s_sim,
    m=m_sim,
)

# One prior-predictive draw. Squeezing drops the num_samples=1 axis.
sim = pyro.infer.Predictive(model_sim, num_samples=1)()
sim = {k: sim[k].detach().cpu().numpy().squeeze() for k in sim}

In [None]:
# sim['pi'] is (n, s); take the largest component weight in each row.
plt.hist(sim['pi'].max(1))
plt.xlabel("largest pi entry per sample")
plt.ylabel("count")

In [None]:
# Compare each component's rho with its mean pi weight across samples.
plt.scatter(sim['rho'], sim['pi'].mean(0))
plt.xlabel("rho (simulated)")
plt.ylabel("mean pi over samples (simulated)")

In [None]:
# Hierarchically clustered view of the simulated pi matrix (n, s).
sns.clustermap(data=sim['pi'])

In [None]:
# Hierarchically clustered view of the simulated gamma matrix (s, g).
sns.clustermap(data=sim['gamma'])

In [None]:
# Fit the model to the simulated counts, conditioning on the same
# hyperparameter values used in the simulation cell.
model_fit = conditioned_model(
    model,
    s=s_sim,
    m=m_sim,
    data={
        "gamma_hyper": 0.01,
        "pi_hyper": 1e-2,
        "rho_hyper": 1e-0,
        "epsilon_hyper": 0.01,
        "alpha_hyper": 1000,
        "y": sim["y"],
    },
)

mapest, history = find_map(
    model_fit,
    learning_rate=1e-1,
    lag=200,
    auto_guide=partial(
        pyro.infer.autoguide.AutoLowRankMultivariateNormal, rank=10
    ),
    num_samples=1000,
)

In [None]:
# Posterior-mean gamma (draws are stacked on axis 0), on a fixed [0, 1] color scale.
sns.clustermap(data=mapest['gamma'].mean(axis=0), vmax=1, vmin=0)

In [None]:
# The first of the posterior draws of pi (draws are stacked on axis 0).
first_draw_index = 0
sns.clustermap(data=mapest['pi'][first_draw_index])

In [None]:
# mapest stacks num_samples draws on axis 0: rho is (draws, s) and pi is
# (draws, n, s). Average both over draws (and pi over samples too) so each
# point is one component, mirroring the simulated-data scatter.
plt.scatter(mapest['rho'].mean(0), mapest['pi'].mean((0, 1)))
plt.xlabel("posterior mean rho")
plt.ylabel("posterior mean pi (averaged over samples)")

In [None]:
# Mean absolute error of the posterior-mean p_noerr against the simulated
# p_noerr. mapest stacks draws on axis 0, so average them first; otherwise the
# sum runs over every draw while dividing only by n * g. Comparing against
# sim['p_noerr'] (not the error-mixed sim['p']) keeps the comparison like-for-like.
np.abs(mapest['p_noerr'].mean(0) - sim['p_noerr']).mean()

In [None]:
# Inspect which sites find_map returned in mapest.
mapest.keys()

In [None]:
# mapest['alpha'] is (draws, n). Plotting it directly draws one line per
# sample; plot the posterior mean per sample instead.
plt.plot(mapest['alpha'].mean(0))
plt.xlabel("sample index")
plt.ylabel("posterior mean alpha")

In [None]:
# Scratch help lookup removed: `scipy.stats.dirichlet.logpdf?` is IPython-only
# syntax (a SyntaxError in plain Python). Use
# help(scipy.stats.dirichlet.logpdf) interactively if needed.

In [None]:
# TODO: Explore distributions formed by powering of dirichlet distributed random variables.

def closure(x):
    """Rescale ``x`` so it sums to one along its last axis (shape preserved).

    ``keepdims=True`` keeps the totals as a trailing length-1 axis, so a 2-D
    input is divided row by row. Without it, ``x / x.sum(-1)`` broadcasts the
    row totals across columns, producing wrong values or a shape error.
    """
    x = np.asarray(x)
    return x / x.sum(-1, keepdims=True)

xx = np.linspace(0.0000001, 0.9999999, num=10000)
alpha = np.array([1.01, 1, 1])
# Evaluate the Dirichlet log-density along the one-parameter slice
# (x, (1 - x)/2, (1 - x)/2) of the simplex. scipy's dirichlet.logpdf takes
# components on the FIRST axis, so the (3, len(xx)) stack is evaluated in a
# single vectorized call. apply_along_axis on a 1-D array made exactly one such
# call anyway.
log_prob = scipy.stats.dirichlet(alpha).logpdf(
    [xx, (1 - xx) / 2, (1 - xx) / 2]
)
plt.plot(xx, log_prob)

In [None]:
def closure(x):
    """Return ``x`` divided by its totals along the last axis (shape preserved)."""
    arr = np.asarray(x)
    totals = np.sum(arr, axis=-1, keepdims=True)
    return arr / totals

def repeat(f, n=1):
    """Call the zero-argument ``f`` ``n`` times and stack results on a new leading axis."""
    results = []
    for _ in range(n):
        results.append(f())
    return np.stack(results)


# Two-component probability vector, rescaled so its smallest entry is 1 and
# used as a Dirichlet concentration.
p = np.array([6 / 11, 1 - 6 / 11])
alpha = 2.0
p_raised = (p / p.min())

# First component of closure(draws ** alpha) over 100000 Dirichlet draws.
x = closure(scipy.stats.dirichlet(p_raised).rvs(100000)**alpha)[:, 0]
plt.hist(x, bins=100)
# numpy has no `geomean` (AttributeError); scipy.stats.gmean is the geometric mean.
print(scipy.stats.gmean(x))