# Preamble

In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

import pyro
import pyro.distributions as dist
import torch
from functools import partial
#import arviz as az
from pyro.ops.contract import einsum
from tqdm import tqdm


def qqplot(x, y, **kwargs):
    """Draw a quantile-quantile scatter of *x* against *y* on the current axes.

    Both inputs are sorted independently and plotted against each other,
    together with a dashed identity line spanning the joint value range.

    Extra keyword arguments are forwarded to ``Axes.scatter`` and override
    the defaults (``marker='.'``, ``alpha=0.5``).  Per-point arguments such
    as ``c`` are passed through unchanged, i.e. they are NOT re-ordered to
    follow the sorting applied to *x* and *y*.

    Returns the axes that were drawn on.
    """
    xs, ys = np.sort(x), np.sort(y)
    # Joint range of both samples; used for the identity line.
    lo = min(xs[0], ys[0])
    hi = max(xs[-1], ys[-1])

    ax = plt.gca()

    scatter_opts = {'marker': '.', 'alpha': 0.5, **kwargs}
    ax.scatter(xs, ys, **scatter_opts)
    ax.plot([lo, hi], [lo, hi], lw=1, linestyle='--', color='k')
    return ax

def binary_entropy(p, normalize=False, axis=None):
    """Sum of the Bernoulli entropies (in bits) of the probabilities in *p*.

    Parameters
    ----------
    p : array-like
        Probabilities.  Values of exactly 0 or 1 contribute zero entropy
        (the usual ``0 * log(0) = 0`` convention) instead of NaN.
    normalize : bool
        If true, divide the sum by the number of elements that were summed
        over, giving the mean entropy per element.
    axis : None, int or tuple of int
        Axis (or axes) to sum over; ``None`` sums over all elements.

    Returns
    -------
    A numpy scalar or array, as produced by ``np.sum(..., axis=axis)``.
    """
    p = np.asarray(p)
    q = 1 - p
    # Substitute 1 for exact zeros inside the logarithm so that those terms
    # evaluate to 0 * log2(1) = 0 rather than 0 * -inf = NaN.
    log_p = np.log2(np.where(p == 0, 1.0, p))
    log_q = np.log2(np.where(q == 0, 1.0, q))
    ent = np.sum(-(p * log_p + q * log_q), axis=axis)
    if normalize:
        # With axis=None every element was summed; otherwise count the
        # elements along the reduced axis/axes.
        if axis is None:
            n_terms = p.size
        else:
            n_terms = np.prod(np.take(p.shape, axis))
        ent = ent / n_terms
    return ent

def binary_entropy_counts(y, m, normalize=False, axis=None):
    """Binary entropy of smoothed frequencies estimated from counts.

    The frequency is estimated as ``(y + 1) / (m + 2)`` (one pseudo-count
    added to each of the two outcomes) and passed to ``binary_entropy``.

    Parameters
    ----------
    y, m : array-like
        Counts of one outcome and total counts, respectively.
    normalize, axis
        Forwarded to ``binary_entropy``.  (``normalize`` was previously
        accepted but silently ignored.)
    """
    p = (y + 1) / (m + 2)
    return binary_entropy(p, normalize=normalize, axis=axis)

def entropy(p, axis=None):
    """Shannon entropy (in nats) of *p*, summed over *axis*.

    Computes ``-sum(p * log(p))``.  Entries that are exactly zero contribute
    zero (the ``0 * log(0) = 0`` convention) instead of turning the whole
    sum into NaN.  *p* is not re-normalized; the input is converted with
    ``np.asarray`` and a numpy scalar/array is returned.
    """
    p = np.asarray(p)
    # Substitute 1 for exact zeros inside the logarithm: 0 * log(1) == 0.
    safe_p = np.where(p == 0, 1.0, p)
    ent = -(p * np.log(safe_p))
    return np.sum(ent, axis=axis)

In [None]:
# Sanity check of binary_entropy on a fixed 2x3 matrix: summed along axis 1
# and along axis 0, first without and then with normalization.
(
    binary_entropy(np.array([[0.1, 0.9, 0.9], [0.1, 0.1, 0.1]]), axis=1, normalize=False),
    binary_entropy(np.array([[0.1, 0.9, 0.9], [0.1, 0.1, 0.1]]), axis=0, normalize=False),
    binary_entropy(np.array([[0.1, 0.9, 0.9], [0.1, 0.1, 0.1]]), axis=1, normalize=True),
    binary_entropy(np.array([[0.1, 0.9, 0.9], [0.1, 0.1, 0.1]]), axis=0, normalize=True),
)

In [None]:
def NegativeBinomialReparam(mu, r):
    """Build a ``dist.NegativeBinomial`` from the pair ``(mu, r)``.

    ``r`` is passed as ``total_count`` and the success probability is set
    to ``1 / (r / mu + 1)``.  Under the torch/pyro parameterization
    (mean = total_count * probs / (1 - probs)) this should make ``mu`` the
    mean of the distribution -- NOTE(review): confirm against the
    NegativeBinomial documentation of the installed version.
    """
    success_prob = 1 / (r / mu + 1)
    return dist.NegativeBinomial(total_count=r, probs=success_prob)

def as_torch(x, dtype=None, device=None):
    """Return a new ``torch.Tensor`` built from *x*.

    *dtype* and *device* are handed straight to ``torch.tensor``; leaving
    them as ``None`` lets torch infer the dtype and use its default device.
    """
    tensor_options = {'dtype': dtype, 'device': device}
    return torch.tensor(x, **tensor_options)

def as_torch_all(dtype=None, device=None, **kwargs):
    """Convert every keyword argument to a tensor with ``as_torch``.

    Returns a dict with the same keys as *kwargs*; each value is converted
    using the shared *dtype* and *device*.  With no keyword arguments the
    result is an empty dict.
    """
    converted = {}
    for name, value in kwargs.items():
        converted[name] = as_torch(value, dtype=dtype, device=device)
    return converted

In [None]:
as_torch(1.0, device='cuda')

# Model0: Dirichlet

In [None]:
def model0(
    n,
    g,
    s,
    gamma_hyper=as_torch(1.),
    rho_hyper=as_torch(1.),
    pi_hyper=as_torch(1.),
    m_hyper_mu=as_torch(10.),
    m_hyper_r=as_torch(1.),
    epsilon_hyper=as_torch(0.01),
    alpha_hyper=as_torch(100.),
    dtype=None,
    device=None,
):
    """Pyro model with Beta/Dirichlet priors.

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position' and 'strain' plates.
    gamma_hyper
        Both concentration parameters of the Beta prior on ``gamma``.
    rho_hyper
        Scales the all-ones Dirichlet concentration of ``rho``.
    pi_hyper
        Scales ``rho`` to form the Dirichlet concentration of ``pi``.
    m_hyper_mu, m_hyper_r
        Arguments of ``NegativeBinomialReparam`` for the site ``m``.
    epsilon_hyper
        ``epsilon`` is drawn from ``Beta(1, 1 / epsilon_hyper)``.
    alpha_hyper
        ``alpha`` is drawn from ``Gamma(alpha_hyper, 1)``.
    dtype, device
        Used only for the ``torch.ones(s)`` tensor of the ``rho`` prior.

    Returns
    -------
    The value of the 'y' sample site.

    NOTE(review): the default hyper-parameters are tensors created once, at
    function definition time, by ``as_torch`` without an explicit device.
    """

    # gamma sits inside both plates: strain on dim -2, position on dim -1.
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.Beta(gamma_hyper, gamma_hyper)
            )

#     rho_ = pyro.sample('rho_', dist.LogNormal(0, 1 / rho_hyper).expand([s]).to_event())
#     rho = pyro.deterministic('rho', rho_ / rho_.sum())
    # A single length-s Dirichlet draw shared by all samples.
    rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s, dtype=dtype, device=device) * rho_hyper))

    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(rho * pi_hyper))
        # alpha and epsilon are per-sample; the trailing unit dimension lets
        # them broadcast against p_noerr below.
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)

    # m is expanded to [n, g] and declared as a single event.
    m = pyro.sample('m', NegativeBinomialReparam(m_hyper_mu, m_hyper_r).expand([n, g]).to_event())

    # pi @ gamma mixes the rows of gamma with the per-sample weights in pi.
    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    # Blend p_noerr with its complement, giving weight epsilon / 2 to the
    # complement.
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )

    # Beta-binomial counts out of m, with alpha scaling both concentrations.
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ).to_event(),
    )
    return y

In [None]:
n, g, s = 1000, 500, 400

model0_sim = partial(
    pyro.condition(
        model0,
        data=as_torch_all(
            # NOTHING HERE
            dtype=torch.float32,
            device="cuda",
        ),
    ),
    s=s,
    g=g,
    n=n,
    **as_torch_all(
        gamma_hyper=0.01,
        pi_hyper=0.0005,
        rho_hyper=1.,
        m_hyper_mu=10.,
        m_hyper_r=1.,
        epsilon_hyper=0.01,
        alpha_hyper=100.,
        dtype=torch.float32,
        device="cuda",
    ),
    dtype=torch.float32,
    device="cuda",
)

trace = pyro.poutine.trace(model0_sim).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

In [None]:
sim0 = pyro.infer.Predictive(model0_sim, num_samples=1)()
sim0 = {k: sim0[k].detach().cpu().numpy().squeeze() for k in sim0.keys()}

In [None]:
sns.heatmap(sim0['pi'])

In [None]:
plt.plot(np.sort(sim0['pi'].max(0)))

In [None]:
plt.plot(np.sort(sim0['pi'].max(1)))

In [None]:
plt.plot(np.sort(sim0['rho']))

In [None]:
sns.heatmap(sim0['gamma'])

In [None]:
model0_fit = partial(
    pyro.condition(
        model0,
        data=as_torch_all(
            m=sim0['m'], y=sim0['y'],
            dtype=torch.float32,
            device="cuda",
        ),
    ),
    s=s,
    g=g,
    n=n,
    **as_torch_all(
        gamma_hyper=0.1,
        pi_hyper=1.,
        rho_hyper=1.,
        m_hyper_mu=10.,  # Conditioned out
        m_hyper_r=1.,  # Conditioned out
        epsilon_hyper=0.01,
        alpha_hyper=100.,
        dtype=torch.float32,
        device="cuda",
    ),
    dtype=torch.float32,
    device="cuda",
)

_guide = pyro.infer.autoguide.AutoLaplaceApproximation(model0_fit)
# _guide = pyro.infer.autoguide.AutoNormal(model0_fit)

opt = pyro.optim.Adamax({"lr": 1e-1}, {"clip_norm": 100.})
svi = pyro.infer.SVI(
    model0_fit,
    _guide,
    opt,
    loss=pyro.infer.JitTrace_ELBO()
)
pyro.clear_param_store()

n_iter = int(5e2)

history = []

In [None]:
# Run up to n_iter SVI steps, recording the value returned by each step in
# `history` and showing the latest value and its change in the progress bar.
pbar = tqdm(range(n_iter))
for i in pbar:
    elbo = svi.step()

    # Stop as soon as the step result becomes NaN (it is not recorded).
    if np.isnan(elbo):
        break

    # Fit tracking
    history.append(elbo)

    # Reporting: skipped for the first two iterations.
    if i <= 1:
        continue
    pbar.set_postfix({
        'ELBO': history[-1],
        'delta': history[-2] - history[-1],
    })

In [None]:
plt.plot(history)

In [None]:
est0 = pyro.infer.Predictive(model0_fit, guide=_guide, num_samples=100)()
est0 = {k: est0[k].detach().cpu().numpy().mean(0).squeeze() for k in est0.keys()}

In [None]:
sns.heatmap(est0['pi'])

In [None]:
sns.clustermap(est0['gamma'].T)

In [None]:
plt.plot(np.sort(sim0['rho']), label='true_metacommunity')
plt.plot(np.sort(sim0['pi'].mean(0)), label='true_mean')
plt.plot(np.sort(est0['rho']), label='fit_metacommunity')
plt.plot(np.sort(est0['pi'].mean(0)), label='fit_mean')
plt.legend()

# Model1: Gumbel-Softmax

In [None]:
def model1(
    n,
    g,
    s,
    gamma_hyper=as_torch(1.),
    rho_hyper=as_torch(1.),
    pi_hyper=as_torch(1.),
    m_hyper_mu=as_torch(10.),
    m_hyper_r=as_torch(1.),
    epsilon_hyper=as_torch(0.01),
    alpha_hyper=as_torch(100.),
    dtype=None,
    device=None,
):
    """Pyro model using relaxed (Gumbel-softmax style) priors.

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position' and 'strain' plates.
    gamma_hyper
        Temperature of the ``RelaxedBernoulli`` prior on ``gamma``
        (``probs`` fixed at 0.5).
    rho_hyper
        Temperature of the ``RelaxedOneHotCategorical`` prior on ``rho``
        (all-zero logits of length ``s``).
    pi_hyper
        Temperature of the ``RelaxedOneHotCategorical`` prior on ``pi``
        (``probs=rho``).
    m_hyper_mu, m_hyper_r
        Arguments of ``NegativeBinomialReparam`` for the site ``m``.
    epsilon_hyper
        ``epsilon`` is drawn from ``Beta(1, 1 / epsilon_hyper)``.
    alpha_hyper
        ``alpha`` is drawn from ``Gamma(alpha_hyper, 1)``.
    dtype, device
        Used for the constant tensors of the ``gamma`` and ``rho`` priors.

    Returns
    -------
    The value of the 'y' sample site.

    NOTE(review): the default hyper-parameters are tensors created once, at
    function definition time, by ``as_torch`` without an explicit device.
    """

    # gamma sits inside both plates: strain on dim -2, position on dim -1.
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, probs=torch.tensor(0.5, dtype=dtype, device=device))
            )

#     rho_ = pyro.sample('rho_', dist.LogNormal(0, 1 / rho_hyper).expand([s]).to_event())
#     rho = pyro.deterministic('rho', rho_ / rho_.sum())
    # A single length-s relaxed one-hot draw shared by all samples.
    rho = pyro.sample('rho', dist.RelaxedOneHotCategorical(temperature=rho_hyper, logits=torch.zeros(s, dtype=dtype, device=device)))

    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
        # epsilon and alpha are per-sample; the trailing unit dimension lets
        # them broadcast against p_noerr below.
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)


    # m is expanded to [n, g] and declared as a single event.
    m = pyro.sample('m', NegativeBinomialReparam(m_hyper_mu, m_hyper_r).expand([n, g]).to_event())

    # pi @ gamma mixes the rows of gamma with the per-sample weights in pi.
    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    # Blend p_noerr with its complement, giving weight epsilon / 2 to the
    # complement.
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )

    # Beta-binomial counts out of m, with alpha scaling both concentrations.
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ).to_event(),
#         dist.Binomial(
#             probs=p,
#             total_count=m
#         ).to_event(),
    )
    return y

In [None]:
model1_sim = partial(
    pyro.condition(
        model1,
        data=as_torch_all(
            # NOTHING HERE
            dtype=torch.float32,
            device="cuda",
        ),
    ),
    s=s,
    g=g,
    n=n,
    **as_torch_all(
        gamma_hyper=0.001,
        pi_hyper=0.2,
        rho_hyper=2.,
        m_hyper_mu=10.,
        m_hyper_r=1.,
        epsilon_hyper=0.01,
        alpha_hyper=100.,
        dtype=torch.float32,
        device="cuda",
    ),
    dtype=torch.float32,
    device="cuda",
)

trace = pyro.poutine.trace(model1_sim).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

In [None]:
# Not run so the same simulated data is fit for each model.

sim1 = pyro.infer.Predictive(model1_sim, num_samples=1)()
sim1 = {k: sim1[k].detach().cpu().numpy().squeeze() for k in sim1.keys()}

In [None]:
plt.plot(np.sort(sim1['pi'].max(0)))

In [None]:
plt.scatter(sim1['pi'].sum(0), sim1['pi'].max(0))
plt.xscale('log')

In [None]:
plt.plot(np.sort(sim1['pi'].max(1)))

In [None]:
plt.plot(np.sort(sim1['rho']))
plt.plot(np.sort(sim1['pi'].mean(0)))
#plt.plot(np.sort(sim1['pi'].max(0)))

In [None]:
sns.heatmap(sim1['pi'])

In [None]:
sns.clustermap(sim1['gamma'])

In [None]:
sns.clustermap(sim1['p'])

## Estimation

### Pre-Clustering

In [None]:
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import pdist, squareform

# Smoothed observed frequency: one pseudo-count added to each outcome.
p_obs = (sim1['y'] + 1) / (sim1['m'] + 2)

# Rescale the frequencies from [0, 1] to [-1, 1] before clustering on
# cosine distance.
genotype_score = p_obs * 2 - 1
# n_clusters=None + distance_threshold: the number of clusters is determined
# by the complete-linkage distance cut-off rather than fixed in advance.
# NOTE(review): newer scikit-learn releases renamed the `affinity` keyword to
# `metric` -- confirm against the installed version.
agg = AgglomerativeClustering(n_clusters=None, affinity='cosine', linkage='complete', distance_threshold=0.05).fit(genotype_score)
# One cluster label per row of genotype_score.
clust = pd.Series(agg.labels_)
clust.value_counts()

In [None]:
y_total = pd.DataFrame(sim1['y']).groupby(clust).sum()
m_total = pd.DataFrame(sim1['m']).groupby(clust).sum()
clust_genotype = (y_total + 1) / (m_total + 2)

additional_haplotypes = 100
gamma_init = pd.concat([
    clust_genotype, pd.DataFrame(np.ones((additional_haplotypes, clust_genotype.shape[1])) * 0.5)
]).reset_index(drop=True)
sns.clustermap(gamma_init)
#sns.clustermap(clust_genotype)

In [None]:
clust

In [None]:
# Initial composition matrix: one row per sample, one column per haplotype
# in gamma_init.
s_fit = gamma_init.shape[0]
pi_init = np.ones((n, s_fit))
# Up-weight, for every one of the n samples, the column of the cluster that
# sample was assigned to.  (The loop previously ran over range(s_fit), i.e.
# over haplotypes instead of samples, which leaves rows uninitialized when
# s_fit < n and indexes out of bounds when s_fit > n.)
for i in range(n):
    pi_init[i, clust[i]] = (s_fit - 1)
# Normalize each row to sum to 1.
pi_init /= pi_init.sum(1, keepdims=True)
pi_init

### Estimation model

In [None]:
model1_fit = partial(
    pyro.condition(
        model1,
        data=as_torch_all(
            m=sim1['m'], y=sim1['y'],
            dtype=torch.float32,
            device="cuda",
        ),
    ),
    s=s_fit,
    g=g,
    n=n,
    **as_torch_all(
        gamma_hyper=1.,
        pi_hyper=1.,
        rho_hyper=1.,
        m_hyper_mu=10.,  # Conditioned out
        m_hyper_r=1.,  # Conditioned out
        epsilon_hyper=0.01,
        alpha_hyper=200.,
        dtype=torch.float32,
        device="cuda",
    ),
    dtype=torch.float32,
    device="cuda",
)

_guide = pyro.infer.autoguide.AutoLaplaceApproximation(
    model1_fit,
    init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
        values={
            'gamma': torch.tensor(gamma_init.values, dtype=torch.float32, device="cuda"),
            'pi': torch.tensor(pi_init, dtype=torch.float32, device="cuda"),
            'rho': torch.tensor(np.ones(s_fit) / s_fit, dtype=torch.float32, device="cuda"),
            'alpha': torch.tensor(200. * np.ones(n), dtype=torch.float32, device="cuda"),
        }
    ),
)
# _guide = pyro.infer.autoguide.AutoNormal(model1_fit)
opt = pyro.optim.Adamax({"lr": 1e-1}, {"clip_norm": 100.})
svi = pyro.infer.SVI(
    model1_fit,
    _guide,
    opt,
    loss=pyro.infer.JitTrace_ELBO()
)
pyro.clear_param_store()

n_iter = int(2e3)

history = []

### Gradient descent

In [None]:
# Run up to n_iter SVI steps, recording the value returned by each step in
# `history` and showing the latest value and its change in the progress bar.
pbar = tqdm(range(n_iter))
for i in pbar:
    elbo = svi.step()

    # Stop as soon as the step result becomes NaN (it is not recorded).
    if np.isnan(elbo):
        break

    # Fit tracking
    history.append(elbo)

    # Reporting: skipped for the first two iterations.
    if i <= 1:
        continue
    pbar.set_postfix({
        'ELBO': history[-1],
        'delta': history[-2] - history[-1],
    })

In [None]:
plt.plot(history)

In [None]:
est1 = pyro.infer.Predictive(model1_fit, guide=_guide, num_samples=1)()
est1 = {k: est1[k].detach().cpu().numpy().mean(0).squeeze() for k in est1.keys()}

### Diagnostics

In [None]:
# If alpha collapses to near 0 in the fit, then
# the model needs to be re-adjusted.

qqplot(sim1['alpha'], est1['alpha'])

In [None]:
qqplot(sim1['epsilon'], est1['epsilon'])

In [None]:
# Q: Why do we have accuracy collapse?

plt.scatter(sim1['alpha'], sim1['epsilon'])
plt.scatter(est1['alpha'], est1['epsilon'])

In [None]:
plt.scatter(gamma_init.values[0], est1['gamma'][0])

In [None]:
plt.scatter(pi_init[0], est1['pi'][0])

### Comparison to ground truth

In [None]:
# position_error = np.sqrt((((est1['p_noerr'] - sim1['p_noerr'])**2).mean(0)))
fit_error = np.sqrt((((est1['p_noerr'] - sim1['p_noerr'])**2)).mean(1))
observation_error = np.sqrt((((sim1['p_noerr'] - np.nan_to_num(sim1['y'] / sim1['m'], nan=0.5))**2).mean(1)))
prediction_error = np.sqrt((((est1['p_noerr'] - np.nan_to_num(sim1['y'] / sim1['m'], nan=0.5))**2).mean(1)))


bins = np.linspace(0, 1, num=50)
#plt.hist(position_error, bins=bins, label='position', alpha=0.75)
plt.hist(fit_error, bins=bins, label='fit', alpha=0.5)
plt.hist(observation_error, bins=bins, label='observation', alpha=0.5)
plt.hist(prediction_error, bins=bins, label='prediction', alpha=0.5)

plt.legend()

In [None]:
fig, axs = plt.subplots(3, figsize=(5, 10))
art0 = axs[0].scatter(prediction_error, fit_error, c=sim1['pi'].max(1), marker='.')
art1 = axs[1].scatter(prediction_error, fit_error, c=est1['pi'].max(1), marker='.')
# art1 = axs[1].scatter(prediction_error, sample_error, c=est1['alpha'], marker='.')
art2 = axs[2].scatter(prediction_error, fit_error, c=sim1['epsilon'], marker='.')
fig.colorbar(art0, ax=axs[0])
fig.colorbar(art1, ax=axs[1])
fig.colorbar(art2, ax=axs[2])

In [None]:
from scipy.spatial.distance import cdist, pdist
import pandas as pd

strain_dist = pd.DataFrame(cdist(sim1['gamma'] * 2 - 1, est1['gamma'] * 2 - 1, metric='cosine'))

best_dist = strain_dist.min(1)

plt.scatter(sim1['pi'].sum(0), best_dist, c=sim1['pi'].max(0), marker='.')
plt.xscale('log')
plt.yscale('log')
plt.ylim(top=1e-0)

In [None]:
plt.scatter(sim1['pi'].max(1), est1['pi'].max(1), marker='.', c=sim1['epsilon'])

In [None]:
# bins = np.linspace(0, 1, num=50)
# plt.hist(position_error, bins=bins, label='position', alpha=0.75)
# plt.hist(sample_error, bins=bins, label='sample', alpha=0.75)
# plt.legend()

plt.scatter(observation_error, prediction_error, marker='.', alpha=0.5, c=sim1['alpha'])

In [None]:
plt.scatter(est1['epsilon'], fit_error, marker='.', alpha=0.5, c=sim1['epsilon'])
plt.colorbar()

In [None]:
plt.scatter(sim1['pi'].max(1), fit_error, marker='.', alpha=0.5, c=est1['alpha'])
plt.colorbar()

In [None]:
plt.hist(1 - pdist(est1['gamma'].T, metric='correlation'), bins=np.linspace(-1, 1, num=100))
None

In [None]:
sns.clustermap(est1['pi'])

In [None]:
sns.clustermap(est1['gamma'].T)

In [None]:
# TODO: Merge close strains
# TODO: How accurate are the strain estimates before/after merging?

# Model2: Hybrid

In [None]:
def model2(
    n,
    g,
    s,
    gamma_hyper=as_torch(1.),
    rho_hyper=as_torch(1.),
    pi_hyper=as_torch(1.),
    m_hyper_mu=as_torch(10.),
    m_hyper_r=as_torch(1.),
    epsilon_hyper=as_torch(0.01),
    alpha_hyper=as_torch(100.),
    dtype=None,
    device=None,
):
    """Pyro model using relaxed priors for gamma, rho and pi.

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position' and 'strain' plates.
    gamma_hyper
        Temperature of the ``RelaxedBernoulli`` prior on ``gamma``
        (logits fixed at 0).
    rho_hyper
        Temperature of the ``RelaxedOneHotCategorical`` prior on ``rho``
        (all-zero logits of length ``s``).
    pi_hyper
        Temperature of the ``RelaxedOneHotCategorical`` prior on ``pi``
        (``probs=rho``).
    m_hyper_mu, m_hyper_r
        Arguments of ``NegativeBinomialReparam`` for the site ``m``.
    epsilon_hyper
        ``epsilon`` is drawn from ``Beta(1, 1 / epsilon_hyper)``.
    alpha_hyper
        ``alpha`` is drawn from ``Gamma(alpha_hyper, 1)``.
    dtype, device
        Used for the constant tensors of the ``gamma`` and ``rho`` priors.

    Returns
    -------
    The value of the 'y' sample site.

    NOTE(review): the default hyper-parameters are tensors created once, at
    function definition time, by ``as_torch`` without an explicit device.
    """

    # gamma sits inside both plates: strain on dim -2, position on dim -1.
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            # Use a float literal for the logits: with dtype=None an integer
            # literal would make torch.tensor infer an int64 tensor.
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, logits=torch.tensor(0., dtype=dtype, device=device))
            )

#     rho_ = pyro.sample('rho_', dist.LogNormal(0, 1 / rho_hyper).expand([s]).to_event())
#     rho = pyro.deterministic('rho', rho_ / rho_.sum())
#     rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s, dtype=dtype, device=device) * rho_hyper))
    # A single length-s relaxed one-hot draw shared by all samples.
    rho = pyro.sample('rho', dist.RelaxedOneHotCategorical(temperature=rho_hyper, logits=torch.zeros(s, dtype=dtype, device=device)))

    with pyro.plate('sample', n, dim=-1):
#         pi = pyro.sample('pi', dist.Dirichlet(rho * pi_hyper))
        pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
        # alpha and epsilon are per-sample; the trailing unit dimension lets
        # them broadcast against p_noerr below.
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)

    # m is expanded to [n, g] and declared as a single event.
    m = pyro.sample('m', NegativeBinomialReparam(m_hyper_mu, m_hyper_r).expand([n, g]).to_event())

    # pi @ gamma mixes the rows of gamma with the per-sample weights in pi.
    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    # Blend p_noerr with its complement, giving weight epsilon / 2 to the
    # complement.
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )

    # Beta-binomial counts out of m, with alpha scaling both concentrations.
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ).to_event(),
    )
    return y

# def model2(
#     n,
#     g,
#     s,
#     gamma_hyper=as_torch(1.),
#     rho_hyper=as_torch(1.),
#     pi_hyper=as_torch(1.),
#     m_hyper_mu=as_torch(10.),
#     m_hyper_r=as_torch(1.),
#     epsilon_hyper=as_torch(0.01),
#     alpha_hyper=as_torch(100.),
#     dtype=None,
#     device=None,
# ):
    
#     with pyro.plate('position', g, dim=-1):
#         with pyro.plate('strain', s, dim=-2):
#             gamma = pyro.sample(
#                 'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, probs=torch.tensor(0.5, dtype=dtype, device=device))
#             )
    
# #     rho_ = pyro.sample('rho_', dist.LogNormal(0, 1 / rho_hyper).expand([s]).to_event())
# #     rho = pyro.deterministic('rho', rho_ / rho_.sum())
#     rho = pyro.sample('rho', dist.RelaxedOneHotCategorical(temperature=rho_hyper, logits=torch.zeros(s, dtype=dtype, device=device)))
    
#     with pyro.plate('sample', n, dim=-1):
#         pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
#         epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
#         alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)

        
#     m = pyro.sample('m', NegativeBinomialReparam(m_hyper_mu, m_hyper_r).expand([n, g]).to_event())

#     p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
#     p = pyro.deterministic('p',
#         (1 - epsilon / 2) * (p_noerr) +
#         (epsilon / 2) * (1 - p_noerr)
#     )
        
#     y = pyro.sample(
#         'y',
#         dist.BetaBinomial(
#             concentration1=alpha * p,
#             concentration0=alpha * (1 - p),
#             total_count=m
#         ).to_event(),
# #         dist.Binomial(
# #             probs=p,
# #             total_count=m
# #         ).to_event(),
#     )
#     return y

## Simulate

In [None]:
n, g, s = 200, 500, 100

model2_sim = partial(
    pyro.condition(
        model2,
        data=as_torch_all(
            # NOTHING HERE
            dtype=torch.float32,
            device="cuda",
        ),
    ),
    s=s,
    g=g,
    n=n,
    **as_torch_all(
        gamma_hyper=0.001,
        pi_hyper=0.3,
        rho_hyper=3.,
        m_hyper_mu=2.,
        m_hyper_r=10.,
        epsilon_hyper=0.01,
        alpha_hyper=100.,
        dtype=torch.float32,
        device="cuda",
    ),
    dtype=torch.float32,
    device="cuda",
)

trace = pyro.poutine.trace(model2_sim).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

In [None]:
# Not run so the same simulated data is fit for each model.

sim2 = pyro.infer.Predictive(model2_sim, num_samples=1)()
sim2 = {k: sim2[k].detach().cpu().numpy().squeeze() for k in sim2.keys()}

In [None]:
plt.plot(np.sort(sim2['rho'] * n))
plt.plot(np.sort(sim2['pi'].sum(0)))

In [None]:
plt.plot(np.sort(sim2['pi'].max(0)))

In [None]:
plt.scatter(sim2['pi'].sum(0), sim2['pi'].max(0))

In [None]:
plt.plot(np.sort(sim2['pi'].max(1)))

In [None]:
plt.scatter(sim2['pi'].max(1), entropy(sim2['pi'], axis=1), marker='.', alpha=0.5)

In [None]:
sns.clustermap(sim2['pi'])

In [None]:
sns.clustermap(sim2['gamma'] * 2 - 1, metric='cosine', vmin=0, vmax=1)

## Pre-cluster

In [None]:
from sklearn.metrics import pairwise_distances

def genotype_distance(x, y):
    """Weighted mean squared difference between two frequency vectors.

    Both inputs are rescaled from [0, 1] to [-1, 1].  Each position is
    weighted by the squared product of the two rescaled values, so a
    position where either vector is at 0.5 (rescaled 0) gets zero weight.
    The per-position distance is the squared half-difference of the
    rescaled values.

    NOTE(review): if every weight is zero the result is 0 / 0 (NaN).
    """
    sx = x * 2 - 1
    sy = y * 2 - 1
    weight = (sx * sy) ** 2
    sq_diff = ((sx - sy) / 2) ** 2
    return (weight * sq_diff).sum() / (weight.sum())

p_obs = (sim2['y'] + 1) / (sim2['m'] + 2)
sample_genotype_dist_matrix = pairwise_distances(p_obs, metric=genotype_distance, n_jobs=2)

In [None]:
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import pdist, squareform

#genotype_score = p_obs * 2 - 1

# print("Calculating distance matrix.")
# sample_genotype_dist_matrix = pdist(p_obs, metric=genotype_distance)
# print("Done calculating distance matrix.")
# print("Clustering samples.")
# Cluster on the precomputed sample-by-sample distance matrix; the number of
# clusters is determined by the complete-linkage distance cut-off.
# NOTE(review): newer scikit-learn releases renamed the `affinity` keyword to
# `metric` -- confirm against the installed version.
agg = AgglomerativeClustering(n_clusters=None, affinity='precomputed', linkage='complete', distance_threshold=0.05).fit(sample_genotype_dist_matrix)
print("Done clustering samples.")
# One cluster label per row of the distance matrix.
clust = pd.Series(agg.labels_)
clust.value_counts()

#sns.clustermap(genotype_score, metric='cosine', row_colors=mpl.cm.viridis(clust.values / clust.max()))

In [None]:
# Pool the counts of all samples assigned to the same cluster and turn them
# into a smoothed frequency (one pseudo-count per outcome) per cluster.
y_total = pd.DataFrame(sim2['y']).groupby(clust).sum()
m_total = pd.DataFrame(sim2['m']).groupby(clust).sum()
clust_genotype = (y_total + 1) / (m_total + 2)

# Number of extra rows, filled with 0.5, appended after the cluster rows.
additional_haplotypes = 0  # alternative: clust_genotype.shape[0], which doubles the number of haplotypes (200% of the clustering count)
gamma_init = pd.concat([
    clust_genotype, pd.DataFrame(np.ones((additional_haplotypes, clust_genotype.shape[1])) * 0.5)
]).reset_index(drop=True)
#sns.clustermap(gamma_init)

In [None]:
# Initial composition matrix: one row per sample, one column per row of
# gamma_init.
s_fit = gamma_init.shape[0]
pi_init = np.ones((n, s_fit))
for i in range(n):
    # Up-weight the column indexed by the cluster label of sample i.
    pi_init[i, clust[i]] = (s_fit - 1)
# Normalize each row to sum to 1.
pi_init /= pi_init.sum(1, keepdims=True)

#sns.clustermap(pi_init)

In [None]:
s_fit

## Fit

In [None]:
g_fit = g
# s_fit = s

model2_fit = partial(
    pyro.condition(
        model2,
        data=as_torch_all(
            m=sim2['m'][:, :g_fit], y=sim2['y'][:, :g_fit],
#             alpha=np.ones(n)*1000,
            dtype=torch.float32,
            device="cuda",
        ),
    ),
    s=s_fit,
    g=g_fit,
    n=n,
    **as_torch_all(
#         # True values
#         gamma_hyper=0.001,
#         pi_hyper=0.3,
#         rho_hyper=3.,
        # Fitting values
        gamma_hyper=0.01,
        pi_hyper=0.5,
        rho_hyper=0.5,
        m_hyper_mu=10.,  # Conditioned out
        m_hyper_r=1.,  # Conditioned out
        epsilon_hyper=0.01,
        alpha_hyper=100.,
        dtype=torch.float32,
        device="cuda",
    ),
    dtype=torch.float32,
    device="cuda",
)

eps_adjust_probs = lambda x, eps=1e-5: (x + eps) / (x + eps).sum(-1, keepdims=True)

_guide = pyro.infer.autoguide.AutoLaplaceApproximation(
    model2_fit,
    init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
        values={
            # Smart-initialize
            'gamma': torch.tensor(gamma_init.values[:, :g_fit], dtype=torch.float32, device="cuda"),
            'pi': torch.tensor(pi_init, dtype=torch.float32, device="cuda"),
            'rho': torch.tensor(np.ones(s_fit) / s_fit, dtype=torch.float32, device="cuda"),
            'alpha': torch.tensor(10. * np.ones(n), dtype=torch.float32, device="cuda"),
            'epsilon': torch.tensor(1e-1 * np.ones(n), dtype=torch.float32, device="cuda"),
            # True-initialize
#             'gamma': torch.tensor(eps_adjust_probs(sim2['gamma']), dtype=torch.float32, device="cuda"),
#             'pi': torch.tensor(eps_adjust_probs(sim2['pi']), dtype=torch.float32, device="cuda"),
#             'rho': torch.tensor(eps_adjust_probs(sim2['rho']), dtype=torch.float32, device="cuda"),
#             'alpha': torch.tensor(sim2['alpha'], dtype=torch.float32, device="cuda"),
#             'epsilon': torch.tensor(sim2['epsilon'], dtype=torch.float32, device="cuda"),
        }
    ),
)
# _guide = pyro.infer.autoguide.AutoNormal(model2_fit)
opt = pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100.})
svi = pyro.infer.SVI(
    model2_fit,
    _guide,
    opt,
    loss=pyro.infer.JitTrace_ELBO()
)
pyro.clear_param_store()

history = []

In [None]:
n_iter = 10000
lag = 20

# Run up to n_iter SVI steps with a simple stopping rule: stop once the
# recorded value has, on balance, gone up over the last `lag` history
# entries.  Whatever ends the loop (including a keyboard interrupt), the
# history is plotted and est2 is drawn from the current guide.
try:
    pbar = tqdm(range(n_iter))
    for i in pbar:
        elbo = svi.step()

        # Stop as soon as the step result becomes NaN (it is not recorded).
        if np.isnan(elbo):
            break

        # Fit tracking
        history.append(elbo)

        # Reporting/Breaking: only once more than `lag` iterations have run.
        if i <= lag:
            continue
        delta = history[-2] - history[-1]
        delta_lag = (history[-lag] - history[-1]) / lag
        if delta_lag < 0:
            print("Converged")
            break
        pbar.set_postfix({
            'ELBO': history[-1],
            'delta': delta,
            f'lag{lag}': delta_lag,
        })
except KeyboardInterrupt:
    print("Interrupted")
finally:
    plt.plot(history)
    est2 = pyro.infer.Predictive(model2_fit, guide=_guide, num_samples=1)()
    # Move to numpy, average over the leading sample dimension and drop
    # singleton dimensions.
    est2 = {k: v.detach().cpu().numpy().mean(0).squeeze() for k, v in est2.items()}

## Assess

In [None]:
# If alpha collapses to near 0 in the fit, then
# the model needs to be re-adjusted.

# NOTE(review): qqplot sorts its x and y inputs independently, while `c` is
# forwarded to scatter unchanged (in the original sample order), so the
# colors drawn here do not line up with the plotted points.
qqplot(sim2['alpha'], est2['alpha'], marker='.', alpha=0.5, c=sim2['pi'].max(1), vmin=0)

In [None]:
plt.plot(np.sort(sim2['rho'] * n)[::-1], label='true_rho', alpha=0.75)
plt.plot(np.sort(sim2['pi'].sum(0))[::-1], label='true_sum', alpha=0.75)
plt.plot(np.sort(est2['rho'] * n)[::-1], label='fit_rho', alpha=0.75)
plt.plot(np.sort(est2['pi'].sum(0))[::-1], label='fit_sum', alpha=0.75)
plt.axhline(0, color='k', lw=1, linestyle='--')

plt.legend()

In [None]:
prediction_error = np.sqrt((((est2['p_noerr'] - np.nan_to_num(sim2['y'][:,:g_fit] / sim2['m'][:,:g_fit], nan=0.5))**2).mean(1)))

sample_mean_genotype_entropy = (est2['pi'] @ np.expand_dims(binary_entropy(est2['gamma'], axis=1, normalize=True), 1)).squeeze()
plt.scatter(sample_mean_genotype_entropy, est2['alpha'], c=prediction_error, marker='.', alpha=0.5)
plt.colorbar()

In [None]:
overfit_sample = (sample_mean_genotype_entropy > 0.3)
underfit_sample = (est2['alpha'] < 50)

failed_samples = overfit_sample | underfit_sample
print(overfit_sample.sum(), underfit_sample.sum(), failed_samples.sum())

In [None]:
idxwhere(pd.Series(failed_samples))[:10]

In [None]:
# position_error = np.sqrt((((est2['p_noerr'] - sim2['p_noerr'])**2).mean(0)))
fit_error = np.sqrt((((est2['p_noerr'] - sim2['p_noerr'])**2)).mean(1))
observation_error = np.sqrt((((sim2['p_noerr'] - np.nan_to_num(sim2['y'] / sim2['m'], nan=0.5))**2).mean(1)))
observed_entropy = binary_entropy_counts(sim2['y'], sim2['m'], axis=1)

bins = np.linspace(0, 1, num=50)
#plt.hist(position_error, bins=bins, label='position', alpha=0.75)
plt.hist(fit_error, bins=bins, label='fit', alpha=0.5)
plt.hist(observation_error, bins=bins, label='observation', alpha=0.5)
plt.hist(prediction_error, bins=bins, label='prediction', alpha=0.5)

plt.legend()

In [None]:
max_abundance_per_strain = sim2['pi'].max(0)
max_strain_per_sample = sim2['pi'].argmax(1)
max_abundance_per_sample = sim2['pi'].max(1)
entropy_per_sample = entropy(sim2['pi'], axis=1)
total_abundance_per_strain = sim2['pi'].sum(0)


plt.scatter(total_abundance_per_strain[max_strain_per_sample], entropy_per_sample, marker='.', alpha=0.5, c=failed_samples)
# CONCLUSION: Samples dominated by strains that are not abundant overall are fit much less well.
# So are samples with high observation error!

In [None]:
plt.scatter(sample_mean_genotype_entropy, entropy(est2['pi'], axis=1), marker='.', c=failed_samples)

In [None]:
plt.scatter(observation_error, entropy_per_sample, marker='.', alpha=0.5, c=failed_samples)
# Observation error counter-intuitively decreased for "failed" samples
# at high observation error (due to large epsilon or small alpha)
# since the bad fit has less effect.

In [None]:
plt.scatter(entropy_per_sample, fit_error, marker='.', alpha=0.5, c=failed_samples)
# Observation error counter-intuitively decreased for "failed" samples
# at high observation error (due to large epsilon or small alpha)
# since the bad fit has less effect.

In [None]:
plt.scatter(
    total_abundance_per_strain[max_strain_per_sample],
    max_abundance_per_strain[max_strain_per_sample],
    marker='.',
    alpha=0.5,
    c=prediction_error
)
plt.xscale('log')
plt.colorbar()
# CONCLUSION: Badly fit samples are often dominated by low-total-abundance strains without any high abundance samples.

In [None]:
plt.scatter(entropy(sim2['pi'], axis=1), entropy(est2['pi'], axis=1), c=sample_mean_genotype_entropy, marker='.', alpha=0.5)
plt.colorbar()

In [None]:
sns.clustermap(est2['pi'])

In [None]:
sns.clustermap(est2['gamma'].T * 2 - 1, metric='cosine')

In [None]:
qqplot(sim2['epsilon'], est2['epsilon'])

In [None]:
# Q: Why do we have accuracy collapse?

plt.scatter(sim2['alpha'], sim2['epsilon'], marker='.', alpha=0.5)
plt.scatter(est2['alpha'], est2['epsilon'], marker='.', alpha=0.5)

In [None]:
plt.scatter(gamma_init.values[0], est2['gamma'][0], marker='.', alpha=0.5)

In [None]:
plt.scatter(pi_init[0], est2['pi'][0], marker='.', alpha=0.5)

In [None]:
fig, axs = plt.subplots(5, figsize=(5, 12))
art0 = axs[0].scatter(prediction_error, fit_error, c=sim2['pi'].max(1), marker='.')
art1 = axs[1].scatter(prediction_error, fit_error, c=est2['pi'].max(1), marker='.')
# art1 = axs[1].scatter(prediction_error, sample_error, c=est2['alpha'], marker='.')
art2 = axs[2].scatter(prediction_error, fit_error, c=sim2['epsilon'], marker='.')
art3 = axs[3].scatter(prediction_error, fit_error, c=sample_mean_genotype_entropy, marker='.')
art4 = axs[4].scatter(prediction_error, fit_error, c=failed_samples, marker='.')


fig.colorbar(art0, ax=axs[0])
fig.colorbar(art1, ax=axs[1])
fig.colorbar(art2, ax=axs[2])
fig.colorbar(art3, ax=axs[3])
fig.colorbar(art4, ax=axs[4])

### Quality of haplotype inferences

In [None]:
from scipy.spatial.distance import cdist, pdist
import pandas as pd


# Fit to true
strain_dist = pd.DataFrame(cdist(sim2['gamma'] * 2 - 1, est2['gamma'] * 2 - 1, metric='cosine'))
best_match = strain_dist.idxmin(1)
best_fit_dist = strain_dist.min(1)

# Init to true
strain_dist = pd.DataFrame(cdist(sim2['gamma'] * 2 - 1, gamma_init * 2 - 1, metric='cosine'))
best_init_dist = strain_dist.min(1)

strain_entropy = binary_entropy(est2['gamma'], normalize=True, axis=1)
best_match_entropy = strain_entropy[best_match.values]

plt.scatter(best_fit_dist, best_init_dist, c=best_match_entropy, marker='.', alpha=0.5, norm=mpl.colors.PowerNorm(1/5))
plt.plot([0, 1], [0, 1], color='k', lw=1, linestyle='--')
plt.colorbar()
plt.xscale('log')
plt.yscale('log')
# plt.ylim(top=1e-0)

In [None]:
plt.scatter(sim2['pi'].sum(0), best_fit_dist, c=best_match_entropy, marker='.', alpha=0.5, norm=mpl.colors.PowerNorm(1/5))
plt.colorbar()
plt.xscale('log')
plt.yscale('log')
plt.ylim(top=2)

In [None]:
plt.scatter(sim2['pi'].max(1), est2['pi'].max(1), marker='.', c=failed_samples, alpha=0.2)

In [None]:
# bins = np.linspace(0, 1, num=50)
# plt.hist(position_error, bins=bins, label='position', alpha=0.75)
# plt.hist(sample_error, bins=bins, label='sample', alpha=0.75)
# plt.legend()

plt.scatter(observation_error, prediction_error, marker='.', alpha=0.5, c=failed_samples)

In [None]:
plt.scatter(sim2['epsilon'], fit_error, marker='.', alpha=0.5, c=failed_samples)
plt.colorbar()

In [None]:
plt.scatter(sim2['pi'].max(1), fit_error, marker='.', alpha=0.5, c=failed_samples)
plt.colorbar()

In [None]:
plt.hist(1 - pdist(est2['gamma'].T, metric='correlation'), bins=np.linspace(-1, 1, num=100))
None

### Quality of composition inferences.

In [None]:
# A sense of how accurate taxon calling is, despite possible permutations
# in the fit relative to true.
# When two samples have highly similar communities in reality,
# they have highly similar communities in the fit.

# Pairwise distances between samples (failed samples excluded) in the true
# versus the fitted composition.  x and y are passed by keyword: recent
# seaborn versions no longer accept them positionally.
sns.jointplot(
    x=pdist(sim2['pi'][~failed_samples]),
    y=pdist(est2['pi'][~failed_samples]),
    kind='hex',
    norm=mpl.colors.PowerNorm(1/5),
)