In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm

def rss(x, y):
    """Return the Euclidean distance between ``x`` and ``y``.

    Note the name: this is the *square root* of the residual sum of squares,
    ``sqrt(sum((x - y) ** 2))``, not the RSS itself.
    """
    residual = x - y
    return np.sqrt(np.sum(np.square(residual)))

In [None]:
# Load per-library ref/alt read counts (the file name suggests GT-Pro output)
# and reshape them into an integer xarray DataArray. Its dims come from the
# index levels plus the stacked 'allele' level; .squeeze() then drops any
# length-1 dims (presumably species_id when the file holds one species — TODO confirm).
inpath = 'data/core/100022/gtpro.read_r1.tsv.bz2'

data = (
    pd.read_table(
        inpath,
        names=["library_id", "species_id", "snp_idx",
               "_3", "_4", "_5", "_6",  # columns that are discarded below
               "ref", "alt"],
        index_col=["library_id", "species_id", "snp_idx"],
    )
    .loc[:, ["ref", "alt"]]
    .rename_axis(columns="allele")
    .stack()
    .to_xarray()
    .fillna(0)  # index combinations missing from the file become zero counts
    .astype(int)
    .squeeze()
)
info(data.sizes)

# Total (ref + alt) count for every library/position.
cvrg = data.sum('allele')

In [None]:
def model(
    s,
    m,
    y=None,
    gamma_hyper=torch.tensor(0.).double(),
    pi0=torch.tensor(1.).double(),
    rho0=torch.tensor(1.).double(),
    epsilon0=torch.tensor(0.01).double(),
    alpha0=torch.tensor(1000.).double(),
):
    """Pyro model: each sample is a mixture of ``s`` strains over ``g`` positions.

    Generative story:
      * gamma[s, g] ~ Beta(a, a) with a = exp(-gamma_hyper): a per-strain,
        per-position probability in [0, 1]. A large ``gamma_hyper`` makes
        ``a`` tiny, so gamma piles up near 0 and 1.
      * rho ~ Dirichlet(rho_hyper * ones(s)): strain weights shared by all samples.
      * pi[n, :] ~ Dirichlet(rho * pi_hyper): per-sample strain mixture.
      * alpha[n] ~ Gamma(alpha_hyper, 1); epsilon[n] ~ Beta(1, 1 / epsilon_hyper).
      * p = (1 - epsilon/2) * (pi @ gamma) + (epsilon/2) * (1 - pi @ gamma).
      * y ~ BetaBinomial(alpha * p, alpha * (1 - p), total_count=m).

    Args:
        s: Number of strains.
        m: Total-count tensor; ``n, g = m.shape`` gives the number of
            samples and positions.
        y: Observed counts with the same shape as ``m`` (alt-allele counts in
            this notebook), or ``None`` to sample them.
        gamma_hyper: Sets the Beta prior on gamma (see above).
        pi0: Gamma shape of the prior on ``pi_hyper``.
        rho0: Gamma shape of the prior on ``rho_hyper``.
        epsilon0: ``epsilon_hyper ~ Beta(1, 1 / epsilon0)``.
        alpha0: Gamma shape of the prior on ``alpha_hyper``.

    Returns:
        The value of site ``'y'``: ``y`` itself when given, otherwise a draw.

    The default tensor arguments are created once at definition time and
    shared between calls; this function only reads them.
    """

    n, g = m.shape

    # gamma has shape (..., s, g): the strain plate is dim -2, positions dim -1.
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.Beta(torch.exp(-gamma_hyper), torch.exp(-gamma_hyper))
            )

    # Global strain weights rho, shape (s,).
    rho_hyper = pyro.sample('rho_hyper', dist.Gamma(rho0, 1.))
    rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s).double() * rho_hyper))

    # Hyper-parameters of the per-sample error rate and concentration.
    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon0))
    alpha_hyper = pyro.sample('alpha_hyper', dist.Gamma(alpha0, 1.))

    pi_hyper = pyro.sample('pi_hyper', dist.Gamma(pi0, 1.))

    # Per-sample sites. pi has shape (n, s). alpha and epsilon have shape (n,)
    # and are unsqueezed to (n, 1) so they broadcast across the g positions.
    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(rho * pi_hyper))
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)

    # Error-free expected fraction, (n, s) @ (s, g) -> (n, g).
    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    # Blend p_noerr with its complement at weight epsilon/2 (symmetric error).
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )

    # Likelihood over the full (n, g) grid. Note this site is not inside the
    # 'sample'/'position' plates, so those batch dims are not declared to Pyro
    # as conditionally independent; no subsampling is done here.
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ),
        obs=y
    )
    return y

In [None]:
# Selection constants for the first fit.
MIN_FRAC_POSITIONS_COVERED = 0.25  # keep samples covering > 25% of positions
N_POSITIONS_SUBSAMPLE = 1000       # number of SNP positions used for fitting
POSITION_SEED = 0                  # makes the position subsample reproducible

# Samples with more than MIN_FRAC_POSITIONS_COVERED of positions covered.
high_cvrg_samples = (data.sum('allele') > 0).mean('snp_idx') > MIN_FRAC_POSITIONS_COVERED
print(high_cvrg_samples.sum().values)

# Subsample SNP positions *without* replacement. Drawing with replacement can
# repeat a column, which the model would treat as independent evidence. The
# draw is sorted so the subsample keeps the original position order.
_n_positions = data.sizes['snp_idx']
position_ss = np.sort(
    np.random.default_rng(POSITION_SEED).choice(
        _n_positions,
        size=min(N_POSITIONS_SUBSAMPLE, _n_positions),
        replace=False,
    )
)

In [None]:
# First fit: high-coverage samples only, on the subsampled positions.
_data = data[high_cvrg_samples, position_ss]
s = 400  # number of candidate strains
y_obs = torch.tensor(_data.sel(allele='alt').values).double()  # alt counts
m = torch.tensor(_data.sum('allele').values).double()          # ref + alt counts
n, g = m.shape

# Pin the four hyper-parameter sites. Because they are conditioned, the
# model's pi0 / rho0 / alpha0 / epsilon0 arguments only set priors on observed
# values. They have no effect on the fit and stay at their defaults.
model_fit = partial(
    pyro.condition(
        model,
        data={
            'alpha_hyper': torch.tensor(100.).double(),
            'epsilon_hyper': torch.tensor(0.01).double(),
            'pi_hyper': torch.tensor(1e-5).double(),
            'rho_hyper': torch.tensor(1.0).double(),
        },
    ),
    s=s,
    m=m,
    gamma_hyper=torch.tensor(20.).double(),
)

# One prior trace as a shape sanity check.
trace = pyro.poutine.trace(model_fit).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

In [None]:
# Fit the first model with SVI.
# AutoLaplaceApproximation optimises a MAP point. Per the Pyro docs, call
# `_guide.laplace_approximation(...)` afterwards if a Gaussian (Laplace)
# posterior is wanted. Other guides that were tried: AutoDiagonalNormal,
# AutoNormal, AutoLowRankMultivariateNormal(rank=100),
# AutoIAFNormal(hidden_dim=[500], num_transforms=3), AutoDelta.
_guide = pyro.infer.autoguide.AutoLaplaceApproximation(model_fit)

# Adamax with gradient-norm clipping (RMSprop(lr=1e-3) was also tried).
opt = pyro.optim.Adamax({"lr": 1e-1}, {"clip_norm": 100.})

svi = pyro.infer.SVI(
    model_fit,
    _guide,
    opt,
    loss=pyro.infer.JitTrace_ELBO()
)

pyro.clear_param_store()

N_STEPS = 10000
history = []  # value returned by svi.step each step (an estimate of -ELBO)
pbar = tqdm(range(N_STEPS))
for i in pbar:
    elbo = svi.step(
        y=y_obs,
    )

    # Stop on divergence: a NaN or infinite loss would only poison later steps.
    if not np.isfinite(elbo):
        break
    history.append(elbo)

    # Report the latest loss and its change as soon as two values exist.
    if len(history) >= 2:
        pbar.set_postfix({'ELBO': history[-1], 'delta': history[-2] - history[-1]})

pbar.refresh()

In [None]:
# Training curve for fit 1 (svi.step returns the loss, i.e. an estimate of -ELBO).
fig, ax = plt.subplots()
ax.plot(history)
ax.set(xlabel='SVI step', ylabel='loss (-ELBO)', title='Fit 1: training loss');

In [None]:
# Draw one posterior sample from the fitted guide. With
# AutoLaplaceApproximation used directly as the guide, this is its MAP point.
# Every returned site is converted to a numpy array.
svi_predictive = pyro.infer.Predictive(
    model_fit,
    guide=partial(_guide, s=s, m=m),
    num_samples=1,
)
svi_posterior = {
    site: draws.detach().numpy()
    for site, draws in svi_predictive(y=y_obs).items()
}

In [None]:
# Posterior-mean strain weights for fit 1 (rows: samples, columns: strains).
# The two leading axes of the 'pi' draws are averaged away.
pi_fit = pd.DataFrame(svi_posterior['pi'].mean(axis=0).mean(axis=0))
sns.clustermap(data=pi_fit)

In [None]:
# Largest strain weight within each sample, sorted descending. Since each row
# of pi sums to 1, values near 1 mean the sample is dominated by one strain.
fig, ax = plt.subplots()
ax.plot(pi_fit.max(axis=1).sort_values(ascending=False).values)
ax.set(xlabel='sample (rank)', ylabel='max strain weight', title='Fit 1: dominant-strain weight per sample');

In [None]:
# Largest weight each strain reaches in any sample, sorted descending; the
# tail near 0 are strains that are effectively unused.
fig, ax = plt.subplots()
ax.plot(pi_fit.max(axis=0).sort_values(ascending=False).values)
ax.set(xlabel='strain (rank)', ylabel='max weight over samples', title='Fit 1: strain usage');

In [None]:
# Posterior-mean strain genotypes for fit 1 (rows: strains, columns: positions).
# Collapse all leading draw/padding axes and average them instead of using
# .squeeze(), so the (s, g) shape survives even when s or g is 1.
_gamma_draws = svi_posterior['gamma']
gamma_fit = pd.DataFrame(
    _gamma_draws.reshape((-1,) + _gamma_draws.shape[-2:]).mean(axis=0)
)

sns.clustermap(gamma_fit.T)

In [None]:
# Observed alt-allele fraction per (sample, position). Zero-depth entries are
# set to NaN explicitly instead of computing 0/0, which emits a RuntimeWarning.
# The heatmap renders NaN cells blank.
_y_np = y_obs.numpy()
_m_np = m.numpy()
frac_obs = np.divide(_y_np, _m_np, out=np.full_like(_y_np, np.nan), where=_m_np > 0)

fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(frac_obs, cmap='coolwarm', cbar=False, ax=ax)

In [None]:
# Model-implied alt fraction pi @ gamma (the model's 'p_noerr'). Note that the
# epsilon error term included in the model's 'p' is NOT applied here.
frac_expect = pi_fit.dot(gamma_fit)

fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(frac_expect, cmap='coolwarm', cbar=False, ax=ax)

In [None]:
# Residual map (observed minus expected fraction).
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(frac_obs - frac_expect, cmap='coolwarm', ax=ax)

# Depth-weighted mean absolute residual. Zero-depth cells are NaN and are
# skipped by DataFrame.sum (skipna defaults to True).
(frac_obs - frac_expect).mul(m.numpy()).abs().sum().sum() / m.numpy().sum()

In [None]:
# Per-sample Beta-Binomial concentration alpha. As alpha grows the
# Beta-Binomial approaches a Binomial, i.e. less overdispersion.
fig, ax = plt.subplots()
ax.hist(svi_posterior['alpha'].ravel(), bins=50)
ax.set(xlabel='alpha', ylabel='number of samples', title='Fit 1: per-sample concentration');

In [None]:
# Per-sample error rate epsilon.
fig, ax = plt.subplots()
ax.hist(svi_posterior['epsilon'].ravel(), bins=50)
ax.set(xlabel='epsilon', ylabel='number of samples', title='Fit 1: per-sample error rate');

In [None]:
# Preview only; the full table is (strains x positions).
gamma_fit.head()

In [None]:
# Keep only strains reaching > 1% weight in at least one sample. The same mask
# selects strain columns of pi and strain rows of gamma.
_keep_strain = pi_fit.max(axis=0) > 0.01
pi_fit_drop = pi_fit.loc[:, _keep_strain]
gamma_fit_drop = gamma_fit.loc[_keep_strain, :]

In [None]:
# Second fit: every sample, the same subsampled positions, and the strain
# genotypes ('gamma') pinned to the strains kept after the first fit. With
# gamma and the four hyper-parameter sites conditioned, the remaining latent
# sites are rho, pi, alpha and epsilon.
_data2 = data[:, position_ss]
s2 = pi_fit_drop.shape[1]  # number of retained strains
y_obs2 = torch.tensor(_data2.sel(allele='alt').values).double()  # alt counts
m2 = torch.tensor(_data2.sum('allele').values).double()          # ref + alt counts
n2, g2 = m2.shape

# pi0 / rho0 / alpha0 / epsilon0 only set priors of conditioned sites, so they
# are left at their defaults.
model_fit2 = partial(
    pyro.condition(
        model,
        data={
            'alpha_hyper': torch.tensor(100.).double(),
            'epsilon_hyper': torch.tensor(0.01).double(),
            'pi_hyper': torch.tensor(1e-5).double(),
            'rho_hyper': torch.tensor(1.0).double(),
            'gamma': torch.tensor(gamma_fit_drop.values).double(),
        },
    ),
    s=s2,
    m=m2,
    gamma_hyper=torch.tensor(20.).double(),
)

# One prior trace as a shape sanity check.
trace2 = pyro.poutine.trace(model_fit2).get_trace()
trace2.compute_log_prob()
print(trace2.format_shapes())

In [None]:
# Fit the second model (strain genotypes fixed) with SVI.
# AutoLaplaceApproximation optimises a MAP point. Per the Pyro docs, call
# `_guide2.laplace_approximation(...)` afterwards for a Gaussian posterior.
# Other guides that were tried: AutoDiagonalNormal, AutoNormal,
# AutoLowRankMultivariateNormal(rank=100), AutoIAFNormal, AutoDelta.
_guide2 = pyro.infer.autoguide.AutoLaplaceApproximation(model_fit2)

# Adamax with gradient-norm clipping (RMSprop(lr=1e-3) was also tried).
opt = pyro.optim.Adamax({"lr": 1e-1}, {"clip_norm": 100.})

svi2 = pyro.infer.SVI(
    model_fit2,
    _guide2,
    opt,
    loss=pyro.infer.JitTrace_ELBO()
)

pyro.clear_param_store()

N_STEPS = 10000
history2 = []  # value returned by svi2.step each step (an estimate of -ELBO)
pbar = tqdm(range(N_STEPS))
for i in pbar:
    elbo = svi2.step(
        y=y_obs2,
    )

    # Stop on divergence: a NaN or infinite loss would only poison later steps.
    if not np.isfinite(elbo):
        break
    history2.append(elbo)

    # Report the latest loss and its change as soon as two values exist.
    if len(history2) >= 2:
        pbar.set_postfix({'ELBO': history2[-1], 'delta': history2[-2] - history2[-1]})

pbar.refresh()

In [None]:
# One posterior draw from the second fit's guide (its MAP point when the
# AutoLaplaceApproximation guide is used directly). Every returned site is
# converted to a numpy array.
svi_predictive2 = pyro.infer.Predictive(
    model_fit2,
    guide=partial(_guide2, s=s2, m=m2),
    num_samples=1,
)
svi_posterior2 = {
    site: draws.detach().numpy()
    for site, draws in svi_predictive2(y=y_obs2).items()
}

In [None]:
# Dominant-strain weight per sample in fit 2. Averaging the leading axes (as
# for pi_fit2) instead of .squeeze() keeps the (samples, strains) shape even
# if only one strain was retained.
fig, ax = plt.subplots()
ax.hist(svi_posterior2['pi'].mean(axis=0).mean(axis=0).max(axis=1))
ax.set(xlabel='max strain weight', ylabel='number of samples', title='Fit 2: dominant-strain weight');

In [None]:
# Genotypes of the strains retained for fit 2 (positions x strains).
sns.clustermap(data=gamma_fit_drop.transpose())

In [None]:
# Re-display fit-1 strain weights. This recomputes the same pi_fit as the
# earlier cell; it is kept for side-by-side viewing with fit 2.
pi_fit = pd.DataFrame(np.mean(np.mean(svi_posterior['pi'], axis=0), axis=0))
sns.clustermap(data=pi_fit)

In [None]:
# Posterior-mean strain weights for fit 2 (rows: all samples, columns: retained strains).
pi_fit2 = pd.DataFrame(svi_posterior2['pi'].mean(axis=0).mean(axis=0))

In [None]:
# Dominant-strain weight per sample: fit 1 (high-coverage samples only) vs
# fit 2 (all samples, strains fixed). pi_fit2 rows are all libraries in data
# order, so the plain boolean ndarray selects the same samples as pi_fit.
# Indexing a pandas Series with the xarray DataArray itself is not a
# supported boolean mask.
_high_cvrg_mask = high_cvrg_samples.values
fig, ax = plt.subplots()
ax.scatter(
    pi_fit.max(axis=1).values,
    pi_fit2.max(axis=1)[_high_cvrg_mask].values,
)
ax.set(xlabel='max strain weight (fit 1)', ylabel='max strain weight (fit 2)');

In [None]:
# Fit-1 strain weights as an unclustered heatmap (samples x strains).
# seaborn has no `heatcmap`; the function is `heatmap`.
fig, ax = plt.subplots()
sns.heatmap(pi_fit.values, ax=ax)

In [None]:
# Fit-2 strain weights restricted to the high-coverage samples. Use the
# boolean ndarray as a row mask: `list(DataArray)` yields 0-d DataArrays,
# which .loc does not treat as a boolean mask.
fig, ax = plt.subplots()
sns.heatmap(pi_fit2.loc[high_cvrg_samples.values].values, ax=ax)

In [None]:
# Fit 2: largest strain weight within each sample, sorted descending.
fig, ax = plt.subplots()
ax.plot(pi_fit2.max(axis=1).sort_values(ascending=False).values)
ax.set(xlabel='sample (rank)', ylabel='max strain weight', title='Fit 2: dominant-strain weight per sample');

In [None]:
# Fit 2: largest weight each retained strain reaches in any sample, sorted descending.
fig, ax = plt.subplots()
ax.plot(pi_fit2.max(axis=0).sort_values(ascending=False).values)
ax.set(xlabel='strain (rank)', ylabel='max weight over samples', title='Fit 2: strain usage');

In [None]:
# Mean depth per sample (over all positions) vs the fit-2 dominant-strain
# weight. Both are ordered by library as in `data`. Zero-depth samples cannot
# be shown on the log axis.
fig, ax = plt.subplots()
ax.scatter(cvrg.mean('snp_idx').values, pi_fit2.max(axis=1).values, s=2)
ax.set_xscale('log')
ax.set(xlabel='mean depth per position', ylabel='max strain weight (fit 2)');

In [None]:
# Fit 2 conditions the 'gamma' site on gamma_fit_drop, so its strain genotypes
# are exactly gamma_fit_drop. Define gamma_fit2 from it (it was never assigned,
# which raised a NameError) and preview the table.
gamma_fit2 = gamma_fit_drop
gamma_fit2.head()