In [None]:
# %%shell

# pip install pyro-ppl arviz

# git init .
# git remote add origin https://github.com/bsmith89/gtpro-strain-factorization
# git fetch origin
# git checkout main

# curl -L -o gtpro.nc https://www.dropbox.com/s/3pv7oszorvhbtee/gtpro.nc?dl=1

In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm
import xarray as xr

def rss(x, y):
    """Return the square root of the summed squared differences of `x` and `y`.

    Evaluates ``sqrt(sum((x - y) ** 2))`` with NumPy, i.e. the Euclidean
    distance between the two inputs when they are plain arrays.
    """
    squared_diff = (x - y) ** 2
    return np.sqrt(np.sum(squared_diff))

def binary_entropy(p):
    """Return the entropy, in bits, of a Bernoulli variable with probability `p`.

    Computes ``-p * log2(p) - (1 - p) * log2(1 - p)`` element-wise, using the
    convention ``0 * log2(0) == 0`` so that ``p == 0`` and ``p == 1`` give an
    entropy of 0 instead of NaN (and no divide-by-zero warning).

    `p` may be a scalar, an ndarray, or a pandas object; array-like inputs
    keep their container type because the final arithmetic is performed on
    `p` itself. Values outside ``[0, 1]`` and NaN inputs still yield NaN.
    """
    q = 1 - p
    # Swap exact zeros for 1 before taking the log: log2(1) == 0, so the
    # corresponding term becomes 0 * 0 == 0 rather than 0 * -inf == NaN.
    log_p = np.log2(np.where(p == 0, 1, p))
    log_q = np.log2(np.where(q == 0, 1, q))
    return -p * log_p - q * log_q

def plot_loss_history(loss_history):
    """Plot a loss trace relative to its minimum on a log-scaled y-axis.

    The minimum of `loss_history` is subtracted from every entry before
    plotting, and that minimum is shown in the title as ``+<min>``.
    Drawing happens on the current matplotlib axes, which are returned.
    """
    floor = loss_history.min()
    ax = plt.gca()
    ax.plot(loss_history - floor)
    ax.set_title(f'+{floor}')
    ax.set_yscale('log')
    return ax

In [None]:
def find_map(
    model,
    max_iter=int(1e5),
    learning_rate=1e-0,
):
    """Fit `model` with SVI and return point estimates plus the loss trace.

    An ``AutoLaplaceApproximation`` guide is optimised with Adamax
    (gradient-norm clipping at 100) under ``JitTrace_ELBO`` for at most
    `max_iter` steps. The global Pyro parameter store is cleared before the
    first step. Optimisation stops early if a step returns NaN (that value
    is not recorded) or when the user interrupts with Ctrl-C.

    Returns
    -------
    mapest : dict
        Site name -> squeezed NumPy array, taken from a single
        ``Predictive`` draw through the fitted guide.
    history : numpy.ndarray
        The value returned by ``svi.step()`` at every recorded step.
    """
    guide = pyro.infer.autoguide.AutoLaplaceApproximation(model)
    optimizer = pyro.optim.Adamax(
        optim_args={"lr": learning_rate},
        clip_args={"clip_norm": 100.},
    )
    svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.JitTrace_ELBO())

    pyro.clear_param_store()
    progress = tqdm(range(max_iter))
    losses = []
    try:
        for step in progress:
            loss = svi.step()
            if np.isnan(loss):
                break

            losses.append(loss)

            # Two recorded values are needed before a delta can be shown.
            if step > 1:
                progress.set_postfix({
                    'ELBO': losses[-1],
                    'delta': losses[-2] - losses[-1],
                })
    except KeyboardInterrupt:
        info('Optimization interrupted')
    progress.refresh()

    # One draw through the fitted guide gives the value of every site.
    draws = pyro.infer.Predictive(model, guide=guide, num_samples=1)()
    mapest = {}
    for site, value in draws.items():
        mapest[site] = value.detach().cpu().numpy().squeeze()
    return mapest, np.array(losses)


def mean_residual_count(expect_frac, obs_count, m):
    """Return the depth-weighted mean absolute error of the observed fraction.

    The observed fraction is ``obs_count / m``. Its absolute deviation from
    `expect_frac` is weighted by `m`, summed, and divided by ``m.sum()``.
    NaN deviations (e.g. 0/0 where ``m == 0``) are replaced by 0 first.
    """
    abs_dev = np.abs(obs_count / m - expect_frac)
    abs_dev[np.isnan(abs_dev)] = 0
    weighted_total = (abs_dev * m).sum()
    return weighted_total / m.sum()

In [None]:
def model_cpu(
    s,
    m,
    y=None,
    gamma_hyper=torch.tensor(1.),
    pi0=torch.tensor(1.),
    rho0=torch.tensor(1.),
    epsilon0=torch.tensor(0.01),
    alpha0=torch.tensor(1000.),
):
    """Pyro model factorising an (n, g) count matrix into `s` latent strains.

    Generative structure, as written below:

    * ``gamma`` (s, g): per-strain, per-position value in [0, 1] with a
      symmetric ``Beta(gamma_hyper, gamma_hyper)`` prior.
    * ``rho`` (s,): Dirichlet with concentration ``ones(s) * rho_hyper``.
    * ``pi`` (n, s): per-sample Dirichlet with concentration
      ``rho * s * pi_hyper``.
    * ``p_noerr = pi @ gamma`` (n, g); ``p`` mixes it with its complement
      using weight ``epsilon / 2`` per sample.
    * ``y`` (n, g): ``BetaBinomial(alpha * p, alpha * (1 - p), total_count=m)``.

    Parameters
    ----------
    s : int
        Number of strains (size of the 'strain' plate).
    m : torch.Tensor
        Total counts, shape (n, g); used as the BetaBinomial ``total_count``.
    y : torch.Tensor, optional
        Observed counts passed as ``obs``; when None, ``y`` is sampled.
        (The notebook builds it from the 'alt' allele counts.)
    gamma_hyper : torch.Tensor
        Both concentration parameters of the Beta prior on ``gamma``.
    pi0, rho0, alpha0 : torch.Tensor
        Shape parameters of the ``Gamma(., 1)`` priors on ``pi_hyper``,
        ``rho_hyper`` and ``alpha_hyper`` respectively.
    epsilon0 : torch.Tensor
        ``epsilon_hyper`` has prior ``Beta(1, 1 / epsilon0)``.

    Returns
    -------
    The value of the ``y`` sample site.
    """
    
    n, g = m.shape
    
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.Beta(gamma_hyper, gamma_hyper)
            )
    # gamma.shape == (s, g)
    
    rho_hyper = pyro.sample('rho_hyper', dist.Gamma(rho0, 1.))
    rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s) * rho_hyper))
#     rho_hyper = pyro.sample('rho_hyper', dist.Beta(rho0, 1 - rho0))
#     rho = pyro.sample('rho', dist.Beta(rho_hyper, 1 - rho_hyper).expand([1, s]).to_event(1))
    # rho.shape == (s,)
    
    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon0))
    alpha_hyper = pyro.sample('alpha_hyper', dist.Gamma(alpha0, 1.))
    pi_hyper = pyro.sample('pi_hyper', dist.Gamma(pi0, 1.))
    
    with pyro.plate('sample', n, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(rho * s * pi_hyper))
        alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1) 
    # pi.shape == (n, s)
    # The 'alpha' and 'epsilon' sites are drawn per sample inside the plate of
    # size n; the local variables get a trailing axis from unsqueeze(-1) so
    # they broadcast against the (n, g) matrices below.

    p_noerr = pyro.deterministic('p_noerr', (pi @ gamma))
    # Mix the error-free fraction with its complement, weight epsilon / 2.
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )
    # p.shape == (n, g)

        
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ),
        obs=y
    )
    # y.shape == (n, g)
    return y

In [None]:
# def model_gpu(
#     s,
#     m,
#     y=None,
#     gamma_hyper=torch.tensor(1.).cuda(),
#     pi0=torch.tensor(1.).cuda(),
#     rho0=torch.tensor(1.).cuda(),
#     epsilon0=torch.tensor(0.01).cuda(),
#     alpha0=torch.tensor(1000.).cuda(),
# ):
    
#     n, g = m.shape
    
#     with pyro.plate('position', g, dim=-1):
#         with pyro.plate('strain', s, dim=-2):
#             gamma = pyro.sample(
#                 'gamma', dist.Beta(gamma_hyper, gamma_hyper)
#             )
#     # gamma.shape == (s, g)
    
#     rho_hyper = pyro.sample('rho_hyper', dist.Gamma(rho0, 1.))
#     rho = pyro.sample('rho', dist.Dirichlet(torch.ones(s).cuda() * rho_hyper))
# #     rho_hyper = pyro.sample('rho_hyper', dist.Beta(rho0, 1 - rho0))
# #     rho = pyro.sample('rho', dist.Beta(rho_hyper, 1 - rho_hyper).expand([1, s]).to_event(1))
#     # rho.shape == 's'
    
#     epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon0))
#     alpha_hyper = pyro.sample('alpha_hyper', dist.Gamma(alpha0, 1.))
#     pi_hyper = pyro.sample('pi_hyper', dist.Gamma(pi0, 1.))
    
#     with pyro.plate('sample', n, dim=-1):
#         pi = pyro.sample('pi', dist.Dirichlet(rho * s * pi_hyper))
#         alpha = pyro.sample('alpha', dist.Gamma(alpha_hyper, 1.)).unsqueeze(-1)
#         epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1) 
#     # pi.shape == (n, s)
#     # alpha.shape == epsilon.shape == (n,)

#     p_noerr = pyro.deterministic('p_noerr', (pi @ gamma))
#     p = pyro.deterministic('p',
#         (1 - epsilon / 2) * (p_noerr) +
#         (epsilon / 2) * (1 - p_noerr)
#     )
#     # p.shape == (n, g)

        
#     y = pyro.sample(
#         'y',
#         dist.BetaBinomial(
#             concentration1=alpha * p,
#             concentration0=alpha * (1 - p),
#             total_count=m
#         ),
#         obs=y
#     )
#     # y.shape == (n, g)
#     return y

In [None]:
data = xr.open_dataarray('gtpro.nc').squeeze()
data.sizes

In [None]:
# Share of reads supporting the most common allele: counts are summed over
# 'read', the larger allele count is divided by the total over both 'read'
# and 'allele', and NaN results (0/0 where there are no reads) become 0.
major_allele_rcvrg = (
    data.sum('read').max('allele')
    / data.sum(['read', 'allele'])
).fillna(0)

In [None]:
# Histogram edges: one catch-all bin [0, 0.5) followed by 20 equal-width
# bins spanning [0.5, 1].
bins = np.concatenate([[0], np.linspace(0.5, 1, num=21)])

# Histogram each column of the transposed table over `bins`; the result is
# indexed by the lower edge of each bin.
# NOTE(review): columns are presumably library IDs (later cells select
# columns by sample lists) — confirm against the dims of `data`.
allele_frac_hist = (
    major_allele_rcvrg
    .to_pandas()
    .T
    .apply(lambda col: np.histogram(col, bins=bins)[0])
    .set_index(bins[:-1])
    .rename_axis(index='bin_low')
)

In [None]:
# Fixed seed so the random position subsample below is reproducible.
np.random.seed(1)

# Samples with >5% of positions covered
suff_cvrg_samples = (data.sum(['allele', 'read']) > 0).mean('position') > 0.05
# Draw npos + npos_out distinct position indices in one go; the first `npos`
# are used for fitting and the remainder are kept as held-out positions.
# NOTE(review): `data.shape[1]` assumes 'position' is the second dimension of
# `data` after the squeeze — confirm.
npos = 4000
npos_out = 4000
position_ss_ = np.random.choice(
    np.arange(data.shape[1]),
    size=npos + npos_out,
    replace=False
    )
position_ss, position_ss_out = position_ss_[:npos], position_ss_[npos:]

# Build m, y matrices from data, summing over both reads.
# m: total counts per (sample, position); y_obs: counts of the 'alt' allele.
_data = data[suff_cvrg_samples, position_ss].astype('float32')
m = _data.sum(['read', 'allele']).values
n, g = m.shape
y_obs = _data.sum('read').sel(allele='alt').values


# Condition the model on fixed values for 'alpha', 'epsilon_hyper',
# 'pi_hyper', 'rho_hyper' and on the observed counts 'y'; the remaining
# sample sites stay latent.
# NOTE(review): `model_gpu` exists only as commented-out code in the visible
# notebook, so this cell raises NameError on a fresh kernel unless it is
# defined elsewhere. Every tensor here is also moved with `.cuda()`, so a
# CUDA device is required.
s = 3000  # number of strains passed to the model
model_fit = partial(
    pyro.condition(
        model_gpu,
        data={
          # 'alpha_hyper': torch.tensor(100.).cuda(),
          'alpha': torch.ones(n).cuda() * 100.,
          'epsilon_hyper': torch.tensor(0.01).cuda(),
          'pi_hyper': torch.tensor(1e-1 / s).cuda(),
          'rho_hyper': torch.tensor(1e0).cuda(),
#           'epsilon': torch.ones(n).cuda() * 0.001,
#           'rho': torch.ones(s).cuda() / s,
           'y': torch.tensor(y_obs).cuda(),
        }
    ),
    s=s,
    m=torch.tensor(m).cuda(),
    gamma_hyper=torch.tensor(1e-2).cuda(),
#     pi0=torch.tensor(1e-1).cuda(),
#    rho0=torch.tensor(1e-1).cuda(),
#    alpha0=torch.tensor(100.).cuda(),  # These two params have no effect IF we condition
#    epsilon0=torch.tensor(0.01).cuda(),  #  on epsilon_hyper and alpha_hyper
)

# trace = pyro.poutine.trace(model_fit).get_trace()
# trace.compute_log_prob()
# print(trace.format_shapes())

In [None]:
# NOTE(review): this import sits mid-notebook; the other imports are in the
# import cell near the top.
import pickle

# Load a previously saved `mapest` dict instead of fitting `model_fit` here.
# NOTE(review): nothing in the visible notebook writes 'test.pickle' or runs
# find_map(model_fit); presumably it was produced by an earlier run — confirm
# its provenance. pickle.load can execute arbitrary code, so only load files
# from a trusted source.
with open('test.pickle', 'rb') as f:
    mapest = pickle.load(f)

In [None]:
# Wrap the fitted arrays as labelled tables: `pi_fit` has one row per library
# in `_data`, `gamma_fit` has one column per position in `_data`.
pi_fit = pd.DataFrame(
    data=mapest['pi'],
    index=_data.library_id,
)
gamma_fit = pd.DataFrame(
    data=mapest['gamma'],
    columns=_data.position,
)

In [None]:
plt.plot(pi_fit.max(1).sort_values(ascending=False).values)
plt.axhline(1.0, c='k', lw=1, linestyle='--')

In [None]:
plt.plot(pi_fit.max(0).sort_values(ascending=False).values)
plt.axhline(1.0, c='k', lw=1, linestyle='--')

In [None]:
plt.plot(pi_fit.sum(0).sort_values(ascending=False).values)
plt.plot((pi_fit > 0.15).sum(0).sort_values(ascending=False).values)

In [None]:
plt.hist(mapest['alpha'], bins=100)
None

In [None]:
plt.hist(mapest['epsilon'], bins=50)
None

In [None]:
# x: for each strain, the sum over samples of (strain fraction * mean depth
#    of that sample); y: the strain's mean binary entropy across positions.
plt.scatter(
    (pi_fit.T * m.mean(1)).sum(1),
    binary_entropy(gamma_fit).mean(1),
    s=1,
)
plt.xlabel('estimated-total-coverage')
plt.ylabel('strain-entropy')
plt.xlim(-1, 10)

In [None]:
# Split samples by their largest strain fraction (row-wise max of `pi_fit`,
# sorted ascending before thresholding): > 0.99 counts as low diversity,
# < 0.75 as high diversity.
# NOTE(review): `idxwhere` comes from lib.util and is not shown here;
# presumably it returns the index labels (library IDs) where the boolean
# Series is True — confirm.
low_diversity_samples = idxwhere(pi_fit.max(1).sort_values() > 0.99)
high_diversity_samples = idxwhere(pi_fit.max(1).sort_values() < 0.75)

len(low_diversity_samples), len(high_diversity_samples)

In [None]:
for library_id in low_diversity_samples[:3]:
    d = data.sel(library_id=library_id).sum(['read'])
    d = (d / d.sum('allele')).dropna('position').max('allele')
    plt.hist(d, bins=np.linspace(0.5, 0.9999, num=21), density=True, alpha=0.2)
    plt.yscale('log')

In [None]:
# Flag (sample, position) entries whose major-allele fraction is strictly
# between 0 and 1 and whose depth exceeds 2, then histogram the per-sample
# count of flagged positions on a log-scaled y-axis.
# NOTE(review): `depth` is not defined in any visible cell, so this raises
# NameError on a fresh kernel. Presumably it is the total read count per
# sample/position, e.g. data.sum(['read', 'allele']) — confirm and define it
# before this cell.
high_depth_variable_positions = ((major_allele_rcvrg > 0.0) & (major_allele_rcvrg < 1.0) & (depth > 2))
plt.hist(high_depth_variable_positions.sum('position'))
plt.yscale('log')

In [None]:
high_depth_variable_positions.sum('position').sel(library_id=low_diversity_samples)

In [None]:
high_depth_variable_positions.sum('position').sel(library_id=low_diversity_samples)

In [None]:
sns.clustermap(
    allele_frac_hist[low_diversity_samples]**(1/5),
    metric='cosine',
    vmin=0, vmax=7,
    row_cluster=False,
    figsize=(len(low_diversity_samples)*0.004, 10)
)

In [None]:
sns.clustermap(
    allele_frac_hist[high_diversity_samples]**(1/5),
    metric='cosine',
    vmin=0, vmax=7,
    row_cluster=False,
    figsize=(len(high_diversity_samples)*0.004, 10)
)

In [None]:
# Observed alt-allele fraction per (sample, position). `y_obs` and `m` are the
# NumPy matrices built from `_data` with `.values`, so they are divided
# directly; the 0/0 warnings for zero-depth entries are expected and silenced.
with np.errstate(divide='ignore', invalid='ignore'):
    frac_obs = y_obs / m
# Zero-depth entries (NaN) are filled with 0.5; they are multiplied by
# m == 0 below and therefore contribute nothing to the summary.
frac_obs_ = frac_obs.copy()
frac_obs_[np.isnan(frac_obs_)] = 0.5

frac_expect = mapest['p_noerr'].squeeze()

# Depth-weighted mean absolute deviation between observed and expected
# allele fractions.
print(np.abs((frac_obs_ - frac_expect) * m).sum() / m.sum())

#fig = plt.figure(figsize=(10, 10))
#sns.heatmap(frac_obs[:,:], cmap='coolwarm', cbar=False, vmin=0, vmax=1)

In [None]:
drop_taxa = (pi_fit.max(0) < 0.01)
drop_taxa.sum()

In [None]:
sns.heatmap(gamma_fit.loc[drop_taxa].T, vmin=0, vmax=1, cmap='coolwarm')

In [None]:
# Build m, y matrices from data, summing over both reads.
# This uses the held-out positions (`position_ss_out`) and REBINDS `_data`,
# `m`, `n`, `g` and `y_obs`; cells that expect the in-sample matrices must
# not be re-run after this one without rebuilding them first.
_data = data[suff_cvrg_samples, position_ss_out].astype('float32')
m = _data.sum(['read', 'allele']).values
n, g = m.shape
y_obs = _data.sum('read').sel(allele='alt').values


# Condition the model on the held-out counts and on the previously estimated
# 'pi' (from `mapest`), plus fixed 'alpha' and 'epsilon_hyper'. Unlike the
# fitting cell, 'pi_hyper' and 'rho_hyper' are not conditioned here.
# NOTE(review): `model_gpu` exists only as commented-out code in the visible
# notebook, so this cell raises NameError on a fresh kernel unless it is
# defined elsewhere; a CUDA device is required for the `.cuda()` calls.
s = 3000
model_out = partial(
    pyro.condition(
        model_gpu,
        data={
          # 'alpha_hyper': torch.tensor(100.).cuda(),
          'alpha': torch.ones(n).cuda() * 100.,
          'epsilon_hyper': torch.tensor(0.01).cuda(),
          # 'pi_hyper': torch.tensor(1e-1 / s).cuda(),
          # 'rho_hyper': torch.tensor(1e0).cuda(),
#           'epsilon': torch.ones(n).cuda() * 0.001,
#           'rho': torch.ones(s).cuda() / s,
           'y': torch.tensor(y_obs).cuda(),

           'pi': torch.tensor(mapest['pi']).cuda(),
        }
    ),
    s=s,
    m=torch.tensor(m).cuda(),
    gamma_hyper=torch.tensor(1e-2).cuda(),
#     pi0=torch.tensor(1e-1).cuda(),
#    rho0=torch.tensor(1e-1).cuda(),
#    alpha0=torch.tensor(100.).cuda(),  # These two params have no effect IF we condition
#    epsilon0=torch.tensor(0.01).cuda(),  #  on epsilon_hyper and alpha_hyper
)

# NOTE(review): the commented-out trace below still refers to `model_fit`,
# not `model_out`.
# trace = pyro.poutine.trace(model_fit).get_trace()
# trace.compute_log_prob()
# print(trace.format_shapes())

In [None]:
mapest_out, history = find_map(model_out, learning_rate=1e-0)

In [None]:
# Rebuild the in-sample (fitting-position) matrices, since the held-out cell
# rebound `_data`, `m`, `n`, `g` and `y_obs`.
_data = data[suff_cvrg_samples, position_ss].astype('float32')
m = _data.sum(['read', 'allele']).values
y_obs = _data.sum('read').sel(allele='alt').values
n, g = m.shape

# Error-free expected allele fraction from the loaded estimate.
frac_expect = mapest['p_noerr'].squeeze()

def residual_count(expect_frac, obs_count, m):
    """Return the element-wise absolute residual, in counts.

    The observed fraction ``obs_count / m`` is compared with `expect_frac`
    and the absolute difference is scaled by the depth `m`. Entries where
    the observed fraction is NaN (0/0 at ``m == 0``) are filled with 0.5
    first; because they are then multiplied by ``m == 0`` they contribute a
    residual of 0.

    The result has the broadcast shape of the inputs.
    """
    # 0/0 at zero-depth entries is expected; silence the warnings locally.
    with np.errstate(divide='ignore', invalid='ignore'):
        frac_obs = obs_count / m
    frac_obs[np.isnan(frac_obs)] = 0.5
    # Use the local `frac_obs` computed from the arguments, not a notebook
    # global, so the result depends only on the inputs.
    return np.abs((frac_obs - expect_frac) * m)

residual_count(frac_expect, y_obs, m).shape