In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm

In [None]:
def NegativeBinomialReparam(mu, r):
    """Negative binomial distribution parameterised by ``mu`` and ``r``.

    ``r`` is forwarded unchanged as ``total_count``; the success probability
    is derived as ``1 / (r / mu + 1)``, which is algebraically
    ``mu / (mu + r)``.  Under the torch/pyro parameterisation
    (mean = total_count * probs / (1 - probs)) this gives a mean of ``mu``.

    Args:
        mu: Target mean (number or tensor).
        r: ``total_count`` of the distribution (number or tensor).

    Returns:
        A ``dist.NegativeBinomial`` instance.
    """
    r_over_mu = r / mu
    success_prob = 1 / (r_over_mu + 1)
    return dist.NegativeBinomial(total_count=r, probs=success_prob)

def as_torch(dtype=torch.float32, device="cpu", **kwargs):
    """Convert every keyword argument into a tensor of one dtype on one device.

    Args:
        dtype: Target dtype for all returned tensors.
        device: Target device for all returned tensors.
        **kwargs: Name -> value pairs.  Values may be Python numbers,
            (nested) sequences, NumPy arrays or existing tensors.

    Returns:
        dict mapping each keyword name to a new tensor holding a copy of the
        value, cast to ``dtype`` and placed on ``device``.
    """
    def _convert(value):
        if isinstance(value, torch.Tensor):
            # torch.tensor(<tensor>) emits a copy-construct UserWarning;
            # detach().clone() is the recommended way to copy a tensor and
            # keeps the "always returns a fresh, graph-free copy" semantics.
            return value.detach().clone().to(dtype=dtype, device=device)
        return torch.tensor(value, dtype=dtype, device=device)

    return {name: _convert(value) for name, value in kwargs.items()}

# Model0: Dirichlet

In [None]:
def model0(
    n,
    g,
    s,
    gamma_hyper=torch.tensor(1.),
    rho_hyper=torch.tensor(1.),
    pi_hyper=torch.tensor(1.),
    m_hyper_mu=torch.tensor(10.),
    m_hyper_r=torch.tensor(1.),
    epsilon_hyper=torch.tensor(0.01),
    alpha_hyper=torch.tensor(100.),
):
    """Pyro model with Dirichlet priors on ``rho`` and ``pi``.

    Sample sites, in the order they are drawn:
        gamma    (s, g)  Beta(gamma_hyper, gamma_hyper), one per strain/position.
        rho      (s,)    Dirichlet(ones(s) * rho_hyper).
        pi       (n, s)  Dirichlet(rho * pi_hyper * s), one row per sample.
        alpha    (n,)    Gamma(alpha_hyper, 1).
        epsilon  (n,)    Beta(1, 1 / epsilon_hyper).
        m        (n, g)  NegativeBinomialReparam(m_hyper_mu, m_hyper_r).
        y        (n, g)  BetaBinomial(alpha * p, alpha * (1 - p), total_count=m).

    Deterministic sites ``p_noerr = pi @ gamma`` and ``p`` (``p_noerr`` mixed
    with its complement with weight ``epsilon / 2``) are recorded as well.

    Args:
        n: Size of the 'sample' plate.
        g: Size of the 'position' plate.
        s: Size of the 'strain' plate.
        gamma_hyper, rho_hyper, pi_hyper, m_hyper_mu, m_hyper_r,
        epsilon_hyper, alpha_hyper: Prior hyperparameters, used as listed above.

    Returns:
        The value of the 'y' site.
    """
    # gamma: one value per (strain, position) pair.
    with pyro.plate('position', g, dim=-1), pyro.plate('strain', s, dim=-2):
        gamma = pyro.sample('gamma', dist.Beta(gamma_hyper, gamma_hyper))

    # rho: a single simplex over strains.  (A normalised-LogNormal
    # alternative was tried here before and is no longer used.)
    rho_concentration = torch.ones(s) * rho_hyper
    rho = pyro.sample('rho', dist.Dirichlet(rho_concentration))

    # Per-sample quantities.
    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(rho * pi_hyper * s))
        alpha_per_sample = pyro.sample(
            'alpha', dist.Gamma(alpha_hyper, 1.)
        )
        epsilon_per_sample = pyro.sample(
            'epsilon', dist.Beta(1., 1 / epsilon_hyper)
        )
    # Trailing singleton axis so both broadcast against (n, g) tensors.
    alpha = alpha_per_sample.unsqueeze(-1)
    epsilon = epsilon_per_sample.unsqueeze(-1)

    count_prior = NegativeBinomialReparam(m_hyper_mu, m_hyper_r)
    m = pyro.sample('m', count_prior.expand([n, g]))

    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    half_eps = epsilon / 2
    p = pyro.deterministic(
        'p',
        (1 - half_eps) * p_noerr + half_eps * (1 - p_noerr),
    )

    return pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
    )

In [None]:
# Problem size: n = 'sample' plate, g = 'position' plate, s = 'strain' plate.
n = 500
g = 1000
s = 100

# Unconditioned model0 with the simulation hyperparameters bound.
model0_sim = partial(
    pyro.condition(model0, data={}),
    s=s,
    g=g,
    n=n,
    **as_torch(gamma_hyper=0.1, pi_hyper=0.001, rho_hyper=1.),
)

# Run the model once and print the shape of every site as a sanity check.
trace = pyro.poutine.trace(model0_sim).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

In [None]:
# Seed the RNGs so the simulated dataset, which is what the fits are
# conditioned on, is identical on every Restart & Run All.
pyro.set_rng_seed(0)

# Draw one joint sample of every site, then drop the leading size-1
# sample axis and convert to NumPy for plotting.
sim = pyro.infer.Predictive(model0_sim, num_samples=1)()
sim = {site: value.detach().cpu().numpy().squeeze() for site, value in sim.items()}

# Rows follow the 'sample' plate, columns the strain axis of `pi`.
sns.heatmap(sim['pi']).set(xlabel='strain', ylabel='sample', title='Simulated pi');

In [None]:
# Simulated rho entries in ascending order.
plt.plot(np.sort(sim['rho']))
plt.xlabel('strain (sorted by rho)')
plt.ylabel('rho')
plt.title('Simulated rho');

In [None]:
# Condition model0 on the simulated 'm' and 'y' sites; the remaining sites
# stay latent.  gamma/rho/pi hyperparameters are all fixed at 1.0.
# NOTE: an annealing schedule for those three hyperparameters was sketched in
# this cell before; it is disabled, so they stay constant during fitting.
model0_fit = partial(
    pyro.condition(
        model0,
        data={
            'm': torch.tensor(sim['m']),
            'y': torch.tensor(sim['y']),
        },
    ),
    s=s,
    g=g,
    n=n,
    **as_torch(pi_hyper=1.0, rho_hyper=1.0, gamma_hyper=1.0),
)

# Inference machinery: autoguide + Adamax with gradient-norm clipping.
_guide = pyro.infer.autoguide.AutoLaplaceApproximation(model0_fit)
opt = pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100.})
svi = pyro.infer.SVI(
    model=model0_fit,
    guide=_guide,
    optim=opt,
    loss=pyro.infer.JitTrace_ELBO(),
)
# Drop any parameters left in the global store before the first step runs.
pyro.clear_param_store()

# Number of optimisation steps and the per-step loss log filled by the loop.
n_iter = 500
history = []

In [None]:
# Run up to `n_iter` SVI steps, appending each step's returned value to
# `history`.  (Per-iteration hyperparameter updates were prototyped here and
# are currently disabled.)
pbar = tqdm(range(n_iter))
for i in pbar:
    elbo = svi.step()

    # Stop on a NaN objective instead of recording it, and say so rather
    # than ending the loop silently.
    if np.isnan(elbo):
        pbar.write(f'Stopping early: objective is NaN at iteration {i}')
        break

    # Fit tracking
    history.append(elbo)

    # Report as soon as two values exist to difference (the previous
    # `i > 1` guard skipped the first available delta).
    if len(history) >= 2:
        pbar.set_postfix({
            'ELBO': history[-1],
            'delta': history[-2] - history[-1],
        })

In [None]:
# Value returned by each svi.step() call, in order.
plt.plot(history)
plt.xlabel('SVI step')
plt.ylabel('svi.step() value')
plt.title('model0 fit history');

In [None]:
# One draw of every site from the fitted guide pushed through model0_fit.
est = pyro.infer.Predictive(model0_fit, guide=_guide, num_samples=1)()
# Iterate over est's own entries: keying off `sim` (as before) only works
# while both dicts happen to share the same site names.
est = {site: value.detach().cpu().numpy().squeeze() for site, value in est.items()}

In [None]:
# Rows follow the 'sample' plate, columns the strain axis of `pi`.
sns.heatmap(est['pi']).set(xlabel='strain', ylabel='sample', title='model0: estimated pi');

In [None]:
# gamma is (strain, position); transposed so rows are positions.
sns.heatmap(est['gamma'].T).set(xlabel='strain', ylabel='position', title='model0: estimated gamma');

In [None]:
# Estimated rho entries in ascending order.
plt.plot(np.sort(est['rho']))
plt.xlabel('strain (sorted by rho)')
plt.ylabel('rho')
plt.title('model0: estimated rho');

# Model1: Gumbel-Softmax

In [None]:
def model1(
    n,
    g,
    s,
    gamma_hyper=torch.tensor(1.),
    rho_hyper=torch.tensor(1.),
    pi_hyper=torch.tensor(1.),
    m_hyper_mu=torch.tensor(10.),
    m_hyper_r=torch.tensor(1.),
    epsilon_hyper=torch.tensor(0.01),
    alpha_hyper=torch.tensor(100.),
):
    """Pyro model using relaxed (temperature-controlled) discrete priors.

    Sample sites, in the order they are drawn:
        gamma    (s, g)  RelaxedBernoulli(temperature=gamma_hyper, probs=0.5).
        rho      (s,)    RelaxedOneHotCategorical(temperature=rho_hyper,
                         logits=zeros(s)).
        pi       (n, s)  RelaxedOneHotCategorical(temperature=pi_hyper,
                         probs=rho), one row per sample.
        alpha    (n,)    Gamma(alpha_hyper, 1).
        epsilon  (n,)    Beta(1, 1 / epsilon_hyper).
        m        (n, g)  NegativeBinomialReparam(m_hyper_mu, m_hyper_r).
        y        (n, g)  BetaBinomial(alpha * p, alpha * (1 - p), total_count=m).

    Deterministic sites ``p_noerr = pi @ gamma`` and ``p`` (``p_noerr`` mixed
    with its complement with weight ``epsilon / 2``) are recorded as well.

    Args:
        n: Size of the 'sample' plate.
        g: Size of the 'position' plate.
        s: Size of the 'strain' plate.
        gamma_hyper, rho_hyper, pi_hyper: Temperatures of the relaxed priors.
        m_hyper_mu, m_hyper_r, epsilon_hyper, alpha_hyper: Prior
            hyperparameters, used as listed above.

    Returns:
        The value of the 'y' site.
    """
    # gamma: one value per (strain, position) pair.
    with pyro.plate('position', g, dim=-1), pyro.plate('strain', s, dim=-2):
        gamma_prior = dist.RelaxedBernoulli(
            temperature=gamma_hyper, probs=torch.tensor(0.5)
        )
        gamma = pyro.sample('gamma', gamma_prior)

    # rho: a single relaxed one-hot vector over strains with uniform logits.
    # (A normalised-LogNormal alternative was tried here before and is no
    # longer used.)
    rho_prior = dist.RelaxedOneHotCategorical(
        temperature=rho_hyper, logits=torch.zeros(s)
    )
    rho = pyro.sample('rho', rho_prior)

    # Per-sample quantities.
    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample(
            'pi',
            dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho),
        )
        alpha_per_sample = pyro.sample(
            'alpha', dist.Gamma(alpha_hyper, 1.)
        )
        epsilon_per_sample = pyro.sample(
            'epsilon', dist.Beta(1., 1 / epsilon_hyper)
        )
    # Trailing singleton axis so both broadcast against (n, g) tensors.
    alpha = alpha_per_sample.unsqueeze(-1)
    epsilon = epsilon_per_sample.unsqueeze(-1)

    count_prior = NegativeBinomialReparam(m_hyper_mu, m_hyper_r)
    m = pyro.sample('m', count_prior.expand([n, g]))

    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    half_eps = epsilon / 2
    p = pyro.deterministic(
        'p',
        (1 - half_eps) * p_noerr + half_eps * (1 - p_noerr),
    )

    return pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
    )

In [None]:
# Problem size: n = 'sample' plate, g = 'position' plate, s = 'strain' plate.
n = 500
g = 1000
s = 100

# Unconditioned model1 with the simulation temperatures bound.
model1_sim = partial(
    pyro.condition(model1, data={}),
    s=s,
    g=g,
    n=n,
    **as_torch(gamma_hyper=0.1, pi_hyper=0.1, rho_hyper=1.),
)

# Run the model once and print the shape of every site as a sanity check.
trace = pyro.poutine.trace(model1_sim).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

In [None]:
# Condition model1 on the 'm' and 'y' arrays held in `sim`; the remaining
# sites stay latent.  gamma/pi/rho temperatures are all fixed at 1.0.
# NOTE: an annealing schedule for those three hyperparameters was sketched in
# this cell before; it is disabled, so they stay constant during fitting.
model1_fit = partial(
    pyro.condition(
        model1,
        data={
            'm': torch.tensor(sim['m']),
            'y': torch.tensor(sim['y']),
        },
    ),
    s=s,
    g=g,
    n=n,
    **as_torch(gamma_hyper=1., pi_hyper=1., rho_hyper=1.),
)

# Inference machinery: autoguide + Adamax with gradient-norm clipping.
_guide = pyro.infer.autoguide.AutoLaplaceApproximation(model1_fit)
opt = pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100.})
svi = pyro.infer.SVI(
    model=model1_fit,
    guide=_guide,
    optim=opt,
    loss=pyro.infer.JitTrace_ELBO(),
)
# Drop any parameters left in the global store before the first step runs.
pyro.clear_param_store()

# Number of optimisation steps and the per-step loss log filled by the loop.
n_iter = 500
history = []

In [None]:
# Run up to `n_iter` SVI steps, appending each step's returned value to
# `history`.  (Per-iteration hyperparameter updates were prototyped here and
# are currently disabled.)
pbar = tqdm(range(n_iter))
for i in pbar:
    elbo = svi.step()

    # Stop on a NaN objective instead of recording it, and say so rather
    # than ending the loop silently.
    if np.isnan(elbo):
        pbar.write(f'Stopping early: objective is NaN at iteration {i}')
        break

    # Fit tracking
    history.append(elbo)

    # Report as soon as two values exist to difference (the previous
    # `i > 1` guard skipped the first available delta).
    if len(history) >= 2:
        pbar.set_postfix({
            'ELBO': history[-1],
            'delta': history[-2] - history[-1],
        })

In [None]:
# Value returned by each svi.step() call, in order.
plt.plot(history)
plt.xlabel('SVI step')
plt.ylabel('svi.step() value')
plt.title('model1 fit history');

In [None]:
# One draw of every site from the fitted guide pushed through model1_fit.
est = pyro.infer.Predictive(model1_fit, guide=_guide, num_samples=1)()
# Iterate over est's own entries: keying off `sim` (as before) only works
# while both dicts happen to share the same site names.
est = {site: value.detach().cpu().numpy().squeeze() for site, value in est.items()}

In [None]:
# Estimated rho entries in ascending order.
plt.plot(np.sort(est['rho']))
plt.xlabel('strain (sorted by rho)')
plt.ylabel('rho')
plt.title('model1: estimated rho');

In [None]:
# Rows follow the 'sample' plate, columns the strain axis of `pi`.
sns.heatmap(est['pi']).set(xlabel='strain', ylabel='sample', title='model1: estimated pi');

In [None]:
# gamma is (strain, position); transposed so rows are positions.
sns.heatmap(est['gamma'].T).set(xlabel='strain', ylabel='position', title='model1: estimated gamma');

In [None]:
# Estimated rho entries in ascending order (same figure as the earlier rho cell).
plt.plot(np.sort(est['rho']))
plt.xlabel('strain (sorted by rho)')
plt.ylabel('rho')
plt.title('model1: estimated rho');