In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm
import xarray as xr

def rss(x, y):
    """Return the square root of the summed squared differences between `x` and `y`.

    This is the Euclidean distance between the two inputs once they are
    broadcast against each other.
    """
    residual = x - y
    squared_total = np.sum(residual ** 2)
    return np.sqrt(squared_total)

def binary_entropy(p):
    """Return the entropy, in bits, of a Bernoulli variable with probability `p`.

    Works element-wise on scalars, arrays and pandas objects (the container
    type of `p` is preserved by the final arithmetic).

    The boundary cases `p == 0` and `p == 1` evaluate to 0, following the
    convention 0 * log2(0) := 0, instead of producing NaN. NaN inputs and
    values outside [0, 1] still yield NaN.
    """
    q = 1 - p
    # Replace exact zeros by 1 *inside the logarithm only*: log2(1) == 0, so the
    # corresponding term becomes 0 * 0 == 0 rather than 0 * -inf == NaN.
    log2_p = np.log2(np.where(p == 0, 1.0, p))
    log2_q = np.log2(np.where(q == 0, 1.0, q))
    return -p * log2_p - q * log2_q

def plot_loss_history(loss_history):
    """Plot a loss trace relative to its minimum on a log-scaled y-axis.

    The minimum of `loss_history` is subtracted from every entry before
    plotting, and that minimum is shown in the title (prefixed with '+').
    Returns the current matplotlib axes.
    """
    baseline = loss_history.min()
    excess_loss = loss_history - baseline
    plt.plot(excess_loss)
    plt.title(f'+{baseline}')
    plt.yscale('log')
    axes = plt.gca()
    return axes

def mean_residual_count(expect_frac, obs_count, m):
    """Return the count-weighted mean absolute deviation of observed fractions.

    The observed fraction is `obs_count / m`. Its absolute difference from
    `expect_frac` is weighted by `m`, summed, and divided by the total of `m`.
    Entries whose deviation is NaN (e.g. 0 / 0 where `m` is zero) are treated
    as zero deviation.
    """
    abs_dev = np.abs(obs_count / m - expect_frac)
    # NaN deviations (undefined observed fraction) contribute nothing.
    abs_dev[np.isnan(abs_dev)] = 0
    weighted_total = (abs_dev * m).sum()
    return weighted_total / m.sum()

In [None]:
def model(
    s,
    m,
    y=None,
    gamma_hyper=1.,
    pi0=1.,
    rho0=1.,
    epsilon0=0.01,
    alpha0=1000.,
    dtype='float32',
    device='cpu',
):
    """Pyro model of allele counts as a mixture of strain genotypes.

    Generative structure (as sampled below):
      * gamma[s, g] ~ Beta(gamma_hyper, gamma_hyper)   per strain and position
      * rho ~ Dirichlet(ones(s) * rho_hyper)
      * pi[n, s] ~ Dirichlet(rho * s * pi_hyper)       per sample
      * alpha[n] ~ Gamma(alpha_hyper, 1), epsilon[n] ~ Beta(1, 1 / epsilon_hyper)
      * p_noerr = pi @ gamma, and p mixes p_noerr with its complement using
        weight epsilon / 2
      * y[n, g] ~ BetaBinomial(alpha * p, alpha * (1 - p), total_count=m)

    Parameters
    ----------
    s : int
        Size of the 'strain' plate.
    m : array-like, shape (n, g)
        Total counts; also fixes the number of samples `n` and positions `g`.
    y : array-like, shape (n, g), optional
        Observed counts passed as `obs` to the 'y' site.
    gamma_hyper, pi0, rho0, epsilon0, alpha0 : float or tensor
        Hyper-parameters of the priors above.
    dtype : torch.dtype or str
        Floating point type of all tensors. A string such as 'float32' is
        resolved to the matching `torch` attribute.
    device : str or torch.device
        Device on which tensors are created.

    Returns
    -------
    The value of the 'y' sample site.
    """
    # torch factory functions only accept torch.dtype objects, so the string
    # default (and any other string) has to be resolved first.
    if isinstance(dtype, str):
        dtype = getattr(torch, dtype)

    # as_tensor accepts numbers, arrays and existing tensors alike, without
    # the copy-construct warning torch.tensor emits for tensor inputs.
    m, gamma_hyper, pi0, rho0, epsilon0, alpha0 = [
        torch.as_tensor(v, dtype=dtype, device=device)
        for v in [m, gamma_hyper, pi0, rho0, epsilon0, alpha0]
    ]
    if y is not None:
        y = torch.as_tensor(y, dtype=dtype, device=device)

    n, g = m.shape

    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.Beta(gamma_hyper, gamma_hyper)
            )
    # gamma.shape == (s, g)

    rho_hyper = pyro.sample('rho_hyper', dist.Gamma(rho0, 1.))
    rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s, dtype=dtype, device=device) * rho_hyper))

    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon0))
    alpha_hyper = pyro.sample('alpha_hyper', dist.Gamma(alpha0, 1.))
    pi_hyper = pyro.sample('pi_hyper', dist.Gamma(pi0, 1.))

    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(rho * s * pi_hyper))
        # Trailing singleton axis lets alpha / epsilon broadcast over positions.
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
    # pi.shape == (n, s)
    # alpha.shape == epsilon.shape == (n, 1) after unsqueeze

    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )
    # p.shape == (n, g)

    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ),
        obs=y
    )
    # y.shape == (n, g)
    return y

def conditioned_model(
    model,
    model_params=None,
    data=None,
    dtype=torch.float32,
    device='cpu',
    **kwargs,
):
    """Return `model` conditioned on `data` with its parameters pre-bound.

    Parameters
    ----------
    model : callable
        Pyro model. It must accept `dtype` and `device` keyword arguments,
        because both are forwarded so the model builds its tensors with the
        same type and on the same device as the conditioning data.
    model_params : dict, optional
        Keyword arguments for `model`; every value is converted to a tensor.
    data : dict, optional
        Mapping from sample-site name to value, handed to `pyro.condition`
        after conversion to tensors.
    dtype, device
        Type and device used for every tensor created here.
    **kwargs
        Further keyword arguments bound to the model unchanged (e.g. `s`).

    Returns
    -------
    functools.partial wrapping the conditioned model.
    """
    # None defaults avoid sharing one mutable dict across calls.
    model_params = {
        k: torch.tensor(v, dtype=dtype, device=device)
        for k, v in (model_params or {}).items()
    }
    data = {
        k: torch.tensor(v, dtype=dtype, device=device)
        for k, v in (data or {}).items()
    }
    return partial(
        pyro.condition(
            model,
            data=data
        ),
        **model_params,
        # Without forwarding these the model silently falls back to its own
        # defaults, regardless of what the caller requested here.
        dtype=dtype,
        device=device,
        **kwargs,
    )

def find_map(
    model,
    max_iter=int(1e5),
    learning_rate = 1e-0,
):
    """Optimise `model` with SVI and an AutoLaplaceApproximation guide.

    Runs up to `max_iter` Adamax steps (gradient norm clipped at 100) and
    stops early if the loss becomes NaN or the user interrupts the kernel.

    Parameters
    ----------
    model : callable
        Pyro model taking no further arguments (e.g. a `functools.partial`).
    max_iter : int
        Maximum number of optimisation steps.
    learning_rate : float
        Adamax learning rate.

    Returns
    -------
    mapest : dict of str -> numpy.ndarray
        One predictive draw through the fitted guide for every site, squeezed.
    history : numpy.ndarray
        Loss returned by each completed optimisation step.
    """
    guide = pyro.infer.autoguide.AutoLaplaceApproximation(model)
    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adamax(
            optim_args={"lr": learning_rate},
            clip_args={"clip_norm": 100.}
        ),
        loss=pyro.infer.JitTrace_ELBO()
    )

    pyro.clear_param_store()
    history = []
    # The context manager guarantees the progress bar is closed on every exit path.
    with tqdm(range(max_iter)) as pbar:
        try:
            for _ in pbar:
                elbo = svi.step()

                if np.isnan(elbo):
                    # Make the early stop visible instead of ending silently.
                    info('Optimization stopped: loss is NaN')
                    break

                # Fit tracking
                history.append(elbo)

                # Reporting: a delta needs at least two recorded losses.
                if len(history) > 1:
                    pbar.set_postfix({
                        'ELBO': history[-1],
                        'delta': history[-2] - history[-1]
                    })
        except KeyboardInterrupt:
            info('Optimization interrupted')
        pbar.refresh()

    # Draw a single sample of every site through the fitted guide.
    # .cpu() keeps the conversion working when the model runs on another device.
    mapest = {
        k: v.detach().cpu().numpy().squeeze()
        for k, v
        in pyro.infer.Predictive(
            model, guide=guide, num_samples=1
        )().items()
    }
    return mapest, np.array(history)

In [None]:
# Load the count array from gtpro.nc, drop singleton dimensions and pool
# counts over the 'read' dimension.
data = xr.open_dataarray('data/core/102506/gtpro.nc').squeeze().sum('read')
# Show the remaining dimension sizes.
data.sizes

In [None]:
# Per position: number of libraries with a non-zero count for each allele,
# then the smaller of those per-allele numbers.
minor_allele_incid = (data > 0).sum('library_id').min('allele')

# Minimum incidence a position needs to be kept below.
thresh = 1000

# Distribution of the incidence with the cut-off marked as a dashed line.
plt.hist(minor_allele_incid, bins=100)
plt.axvline(thresh, lw=1, linestyle='--', c='k')

# idxwhere presumably returns the index labels where the condition is True
# (helper from lib.util; confirm there).
informative_positions = idxwhere(minor_allele_incid.to_series() > thresh)

print(len(informative_positions))

In [None]:
# Fixed seed so the position subsample below is reproducible.
np.random.seed(1)

# Samples with >25% of positions covered
_position_covered = data.sel(position=informative_positions).sum(['allele']) > 0
suff_cvrg_samples = _position_covered.mean('position') > 0.25

# Draw npos + npos_out distinct informative positions in one go, then split
# them into the first npos and the remaining npos_out.
npos = 1000
npos_out = 1000
position_ss_ = np.random.choice(informative_positions, size=npos + npos_out, replace=False)
position_ss = position_ss_[:npos]
position_ss_out = position_ss_[npos:]

In [None]:
# Build m, y matrices from data.
_data = data.sel(library_id=suff_cvrg_samples, position=position_ss)
# Total counts as a plain ndarray; assumed to be laid out (sample, position).
m = _data.sum('allele').values
# Unpack the *shape*, not the array itself.
n, g = m.shape
y_obs = _data.sel(allele='alt')

s = 3000
model_fit = conditioned_model(
    model,
    # m is already an ndarray here, so it is passed as-is.
    model_params=dict(m=m, gamma_hyper=1e-2),
    data=dict(
        alpha=np.ones(n) * 100,
        epsilon_hyper=0.01,
        pi_hyper=1e-1 / s,
        rho_hyper=1e0,
        y=y_obs.values,
    ),
    s=s,
    dtype=torch.float32,
    device='cpu',
)

# Run the model once and print the shape of every site as a sanity check.
trace = pyro.poutine.trace(model_fit).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

In [None]:
# Fit the conditioned model: at most 1e4 optimisation steps at learning rate 0.5.
mapest, history = find_map(model_fit, learning_rate=5e-1, max_iter=int(1e4))

In [None]:
# Loss trace of the fit above, relative to its minimum.
plot_loss_history(history)

In [None]:
# Wrap the fitted arrays in labelled frames: pi rows are libraries,
# gamma columns are positions.
pi_fit = pd.DataFrame(mapest['pi'], index=_data.library_id)
gamma_fit = pd.DataFrame(mapest['gamma'], columns=_data.position)

In [None]:
# Largest pi entry of each row (library), sorted in decreasing order;
# the dashed line marks 1.0.
plt.plot(pi_fit.max(1).sort_values(ascending=False).values)
plt.axhline(1.0, c='k', lw=1, linestyle='--')

In [None]:
# Largest pi entry of each column (strain) across libraries, sorted in
# decreasing order; the dashed line marks 1.0.
plt.plot(pi_fit.max(0).sort_values(ascending=False).values)
plt.axhline(1.0, c='k', lw=1, linestyle='--')

In [None]:
# Per column (strain): total pi over all libraries, and the number of
# libraries in which its pi exceeds 0.15; each curve is sorted independently.
plt.plot(pi_fit.sum(0).sort_values(ascending=False).values)
plt.plot((pi_fit > 0.15).sum(0).sort_values(ascending=False).values)

In [None]:
# Histogram of the 'alpha' values in the fitted estimate.
plt.hist(mapest['alpha'], bins=100)
# Trailing None suppresses the cell's text output.
None

In [None]:
# Histogram of the 'epsilon' values in the fitted estimate.
plt.hist(mapest['epsilon'], bins=50)
# Trailing None suppresses the cell's text output.
None

In [None]:
# x: pi weighted by each library's mean total count and summed over libraries,
#    giving one value per strain.
# y: mean binary entropy of the fitted gamma over positions, per strain.
plt.scatter((pi_fit.T * m.mean(1)).sum(1), binary_entropy(gamma_fit).mean(1), s=1)
plt.ylabel('strain-entropy')
plt.xlabel('estimated-total-coverage')
plt.xlim(-1, 10)

In [None]:
# Fraction of the counts carried by the most abundant allele; NaN where the
# total count is zero.
_major_allele_frac = data.max('allele') / data.sum('allele')
# Same fraction with undefined entries set to 0.
major_allele_rcvrg = _major_allele_frac.fillna(0)
# Mean of the fraction over libraries.
per_sample_major_allele_mean_coverage = _major_allele_frac.mean('library_id')

In [None]:
# Minor-allele incidence against the library-averaged major-allele fraction.
plt.scatter(minor_allele_incid, per_sample_major_allele_mean_coverage, s=1)

In [None]:
# Bin edges: one catch-all bin [0, 0.5) followed by 20 equal bins on [0.5, 1].
bins = np.concatenate([[0], np.linspace(0.5, 1, num=21)])
# Histogram of the major-allele fraction for every column of the transposed
# table, indexed by the lower edge of each bin.
allele_frac_hist = major_allele_rcvrg.to_pandas().T.apply(lambda x: np.histogram(x, bins=bins)[0]).set_index(bins[:-1]).rename_axis(index='bin_low')

# sns.clustermap(
#     allele_frac_hist**(1/5),
#     metric='cosine',
#     vmin=0, vmax=7,
#     row_cluster=False,
#     figsize=(20, 10)
# )

In [None]:
# Largest pi entry per library, in ascending order.
_top_strain_frac = pi_fit.max(1).sort_values()
# Libraries whose largest entry is above 0.98 / below 0.5 respectively.
low_diversity_samples = idxwhere(_top_strain_frac > 0.98)
high_diversity_samples = idxwhere(_top_strain_frac < 0.5)

len(low_diversity_samples), len(high_diversity_samples)

In [None]:
# Clustered heatmap of the fifth-root-transformed allele-fraction histograms
# for the low-diversity libraries; rows (bins) keep their order.
sns.clustermap(
    allele_frac_hist[low_diversity_samples]**(1/5),
    metric='cosine',
    vmin=0, vmax=7,
    row_cluster=False,
    figsize=(20, 10)
)

In [None]:
# Clustered heatmap of the fifth-root-transformed allele-fraction histograms
# for the high-diversity libraries; rows (bins) keep their order.
sns.clustermap(
    allele_frac_hist[high_diversity_samples]**(1/5),
    metric='cosine',
    vmin=0, vmax=7,
    row_cluster=False,
    figsize=(20, 10)
)

In [None]:
# Two stacked panels with shared axes: high-diversity libraries on top,
# low-diversity libraries below.
fig, axs = plt.subplots(2, sharey=True, sharex=True)
axs[1].set_yscale('log')

# Overlay one nearly transparent histogram of the major-allele fraction per
# library (first 100 libraries of each group).
for _ax, _group in zip(axs, (high_diversity_samples, low_diversity_samples)):
    for library_id in _group[:100]:
        d = data.sel(library_id=library_id)
        d = (d / d.sum('allele')).dropna('position').max('allele')
        _ax.hist(
            d,
            bins=np.linspace(0.5, 0.999, num=11),
            density=False,
            alpha=0.005,
            color='black',
        )
    


In [None]:
# np.asarray yields plain arrays for ndarray, DataArray and CPU-tensor inputs
# alike; neither an ndarray nor a DataArray offers a .numpy() method.
_m = np.asarray(m)
frac_obs = np.asarray(y_obs) / _m
frac_obs_ = frac_obs.copy()
# Undefined fractions (zero total count) get a placeholder; they are
# multiplied by a zero count below and so do not affect the result.
frac_obs_[np.isnan(frac_obs_)] = 0.5

frac_expect = (mapest['p_noerr'].squeeze()) #* m.numpy()

# Count-weighted mean absolute difference between observed and expected fractions.
print(np.abs(((frac_obs_ - frac_expect) * _m)).sum().sum() / _m.sum())

#fig = plt.figure(figsize=(10, 10))
#sns.heatmap(frac_obs[:,:], cmap='coolwarm', cbar=False, vmin=0, vmax=1)

In [None]:
# Flag strains (columns of pi_fit) whose pi never reaches 0.01 in any library,
# then show how many were flagged.
drop_taxa = (pi_fit.max(0) < 0.01)
drop_taxa.sum()

In [None]:
# Fitted gamma of the flagged strains (positions on rows after the transpose).
sns.heatmap(gamma_fit.loc[drop_taxa].T, vmin=0, vmax=1, cmap='coolwarm')

In [None]:
# Build m, y matrices from data, summing over both reads.
_data = data[high_cvrg_samples, :].astype('float32')
m = torch.tensor(_data.sum(['read', 'allele']).values)
n, g = m.shape
y_obs = torch.tensor(_data.sum('read').sel(allele='alt').values)


# Build fully-conditioned model.
s = 1500
model_geno = partial(
    pyro.condition(
        model,
        data={
#           'alpha_hyper': torch.tensor(300.),
          'alpha': torch.ones(n) * 10.,
          'epsilon_hyper': torch.tensor(0.01),
#           'pi_hyper': torch.tensor(1e-1 / s),
#           'rho_hyper': torch.tensor(1e0),
#           'epsilon': torch.ones(n) * 0.001,
#           'rho': torch.ones(s) / s,
           'pi': torch.tensor(mapest['pi']),
           'y': y_obs,
        }
    ),
    s=s,
    m=m,
    gamma_hyper=torch.tensor(1e-0),
#     pi0=torch.tensor(1e-1),
#    rho0=torch.tensor(1e-1),
#    alpha0=torch.tensor(100.),  # These two params have no effect IF we condition
#    epsilon0=torch.tensor(0.01),  #  on epsilon_hyper and alpha_hyper
)

# trace = pyro.poutine.trace(model_fit).get_trace()
# trace.compute_log_prob()
# print(trace.format_shapes())

In [None]:
# Fit the genotype model with find_map's default iteration cap and learning rate.
mapest_geno, history_geno = find_map(model_geno)

In [None]:
# Loss trace of the genotype fit, relative to its minimum.
plot_loss_history(history_geno)

In [None]:
# Refitted gamma as a frame with one column per position of the current _data.
gamma_geno = pd.DataFrame(mapest_geno['gamma'], columns=_data.position) 

In [None]:
# Refitted gamma of the strains that were NOT flagged by drop_taxa.
# NOTE(review): drop_taxa is indexed by the columns of pi_fit, so this needs
# gamma_geno to have one row per strain of the first fit — confirm.
sns.heatmap(gamma_geno.loc[~drop_taxa].T, vmin=0, vmax=1, cmap='coolwarm')

In [None]:
# Per library: binary entropy of pi summed over strains.
sample_h = binary_entropy(pi_fit).sum(1)
# Per strain: binary entropy of the refitted gamma averaged over positions.
strain_h = binary_entropy(gamma_geno).mean(1)

In [None]:
# pi weighted by each library's mean total count and summed over libraries,
# against the per-strain entropy of the refitted gamma.
# NOTE(review): m was rebuilt as a torch tensor in the genotype cell; confirm
# that the pandas * tensor product behaves as intended.
plt.scatter((pi_fit.T * m.mean(1)).sum(1), strain_h, s=1)
plt.ylabel('strain-entropy')
plt.xlabel('estimated-total-coverage')
#plt.xlim(-1, 10)

In [None]:
# Mean total count per library against its summed pi entropy, on log-log axes.
plt.scatter(m.mean(1), sample_h, s=1)
plt.ylabel('sample-entropy')
plt.xlabel('sample-mean-coverage')
plt.yscale('log')
plt.xscale('log')

In [None]:
# Distribution of the per-library entropy over 49 equal bins on [0, 10].
plt.hist(sample_h, bins=np.linspace(0, 10, num=50))
# Trailing None suppresses the cell's text output.
None

In [None]:
# Distribution of the per-strain entropy over 49 equal bins on [0, 1].
plt.hist(strain_h, bins=np.linspace(0, 1, num=50))
# Trailing None suppresses the cell's text output.
None

In [None]:
# Build m, y matrices from data for ALL libraries ('read' was already pooled
# when data was loaded, so only 'allele' is left to sum over).
_data = data.astype('float32')
m = torch.tensor(_data.sum('allele').values)
n, g = m.shape
y_obs = torch.tensor(_data.sel(allele='alt').values)


# Build fully-conditioned model.
# The number of strains must match the conditioned gamma (one row per strain).
s = mapest_geno['gamma'].shape[0]
model_frac = partial(
    pyro.condition(
        model,
        data={
#           'alpha_hyper': torch.tensor(300.),
          'alpha': torch.ones(n) * 10.,
          'epsilon_hyper': torch.tensor(0.01),
          'pi_hyper': torch.tensor(1e-1 / s),
          'rho_hyper': torch.tensor(1e0),
#           'epsilon': torch.ones(n) * 0.001,
#           'rho': torch.ones(s) / s,
           'gamma': torch.tensor(mapest_geno['gamma']),
           'y': y_obs,
        }
    ),
    s=s,
    m=m,
    # Pass a real torch.dtype explicitly; torch cannot build tensors from the
    # string 'float32'.
    dtype=torch.float32,
#     gamma_hyper=torch.tensor(1e-0),
#     pi0=torch.tensor(1e-1),
#    rho0=torch.tensor(1e-1),
#    alpha0=torch.tensor(100.),  # These two params have no effect IF we condition
#    epsilon0=torch.tensor(0.01),  #  on epsilon_hyper and alpha_hyper
)

# trace = pyro.poutine.trace(model_fit).get_trace()
# trace.compute_log_prob()
# print(trace.format_shapes())

In [None]:
# Fit the fraction model with find_map's default iteration cap and learning rate.
mapest_frac, history_frac = find_map(model_frac)