## Imports

In [None]:
# IPython magics: reload edited modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

In [None]:
# Silence the TracerWarning about torch.tensor(...) results being registered as
# trace constants. This presumably comes from JIT tracing (JitTrace_ELBO is
# used later) through model code that calls torch.tensor via as_torch.
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

In [None]:
# logging.py

from datetime import datetime
import sys

def info(*msg):
    """Write a timestamped diagnostic line to stderr.

    The line starts with ``[<current local datetime>]``, followed by all
    positional arguments separated by spaces. Output is flushed immediately.
    """
    stamp = f'[{datetime.now()}]'
    print(stamp, *msg, file=sys.stderr, flush=True)

In [None]:
# torch.py

def as_torch(x, dtype=None, device=None):
    """Build a new tensor from ``x`` with the given dtype, on the given device.

    ``x`` may be anything ``torch.tensor`` accepts (scalar, list, ndarray, ...).
    """
    tensor = torch.tensor(x, dtype=dtype, device=device)
    return tensor

def all_torch(dtype=None, device=None, **kwargs):
    """Convert every keyword argument to a tensor via ``as_torch``.

    Returns a new dict with the same keys, in the same order, where each
    value has been cast to ``dtype`` and placed on ``device``.
    """
    converted = {}
    for name, value in kwargs.items():
        converted[name] = as_torch(value, dtype=dtype, device=device)
    return converted

# pyro.py

def shape_info(model, *args, **kwargs):
    """Trace one execution of a Pyro ``model`` and log its site-shape table.

    ``args``/``kwargs`` are forwarded to the model. Log-probabilities are
    computed before formatting so the table also shows their shapes.
    """
    model_trace = pyro.poutine.trace(model).get_trace(*args, **kwargs)
    model_trace.compute_log_prob()
    shapes = model_trace.format_shapes()
    info(shapes)

## Model Specification

In [None]:
# model.py

import pyro
import pyro.distributions as dist
import torch


def NegativeBinomialReparam(mu, r, eps=1e-5):
    """Negative binomial parameterized by mean ``mu`` and total count ``r``.

    Converts (mu, r) into the ``probs`` parameterization of
    ``dist.NegativeBinomial``: p = 1 / (r / mu + 1) = mu / (mu + r). Under
    torch's convention (mean = total_count * p / (1 - p)), this gives mean mu.
    ``p`` is clamped into [eps, 1 - eps] to stay off the degenerate endpoints.
    """
    ratio = r / mu
    probs = torch.clamp(1. / (ratio + 1.), min=eps, max=1. - eps)
    return dist.NegativeBinomial(total_count=r, probs=probs)


def model_binomial(
    n,
    g,
    s,
    gamma_hyper=1.,
    rho_hyper=1.,
    pi_hyper=1.,
    epsilon_hyper_hyper=0.01,
    mu_hyper_mean=1.,
    mu_hyper_scale=1.,
    m_hyper_r=1.,
    dtype=torch.float32,
    device='cpu',
):
    """Pyro generative model: strain mixtures observed through binomial allele counts.

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position', and 'strain' plates.
    gamma_hyper : float
        Temperature of the RelaxedBernoulli prior on each genotype entry.
    rho_hyper : float
        Temperature of the RelaxedOneHotCategorical prior on ``rho``.
    pi_hyper : float
        Temperature of the RelaxedOneHotCategorical prior on each sample's
        composition ``pi``, whose probabilities are ``rho``.
    epsilon_hyper_hyper : float
        ``epsilon_hyper`` ~ Beta(1, 1 / epsilon_hyper_hyper).
    mu_hyper_mean, mu_hyper_scale : float
        ``mu`` ~ LogNormal(log(mu_hyper_mean), mu_hyper_scale).
    m_hyper_r : float
        ``total_count`` of the negative binomial on depth ``m``.
    dtype, device
        Used for all hyperparameter tensors.

    Sites and shapes: gamma (s, g); rho (s,); epsilon_hyper (); pi (n, s);
    mu (n,); epsilon (n,); m (n, g); p_noerr and p (n, g); y (n, g).
    Returns None.
    """
    
    # Convert all hyperparameters to tensors with the requested dtype/device.
    gamma_hyper, rho_hyper, pi_hyper, epsilon_hyper_hyper, mu_hyper_mean, mu_hyper_scale, m_hyper_r = (
        as_torch(x, dtype=dtype, device=device)
        for x in [gamma_hyper, rho_hyper, pi_hyper, epsilon_hyper_hyper, mu_hyper_mean, mu_hyper_scale, m_hyper_r]
    )

    # Genotypes
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            # Zero logits: prior probability 0.5 for every genotype entry.
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, logits=torch.zeros((1,), dtype=dtype, device=device).squeeze())
            )
    
    # Meta-community composition (uniform logits over the s strains).
    rho = pyro.sample('rho', dist.RelaxedOneHotCategorical(temperature=rho_hyper, logits=torch.zeros(s, dtype=dtype, device=device)))

    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon_hyper_hyper))
    with pyro.plate('sample', n, dim=-1):
        # Community composition
        pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
        # Sample coverage
        mu = pyro.sample('mu', dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale))
        # Sequencing error; unsqueezed to (n, 1) so it broadcasts over positions.
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
        
    # Sample/position coverage: per-sample mean mu, expanded over all g positions.
    m = pyro.sample('m', NegativeBinomialReparam(mu.reshape((-1, 1)), m_hyper_r).expand([n, g]).to_event())
    
    # Error model: mixing the strain genotypes by pi gives the error-free
    # allele frequency; a fraction epsilon/2 of probability then flips between
    # the two alleles.
    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )
    
    # Observation
    y = pyro.sample(
        'y',
        dist.Binomial(
            probs=p,
            total_count=m
        ).to_event(),
    )
    
def model_betabinomial(
    n,
    g,
    s,
    gamma_hyper=1.,
    rho_hyper=1.,
    pi_hyper=1.,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=10.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_hyper=0.01,
    mu_hyper_mean=1.,
    mu_hyper_scale=1.,
    m_hyper_r=1.,
    dtype=torch.float32,
    device='cpu',
):
    """Pyro generative model: strain mixtures observed through beta-binomial counts.

    Same structure as ``model_binomial``, except that counts are
    over-dispersed: each sample has a concentration ``alpha``, and
    y ~ BetaBinomial(alpha * p, alpha * (1 - p), m).

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position', and 'strain' plates.
    gamma_hyper, rho_hyper, pi_hyper : float
        Temperatures of the relaxed priors on gamma, rho, and pi.
    alpha_hyper_hyper_mean, alpha_hyper_hyper_scale : float
        ``alpha_hyper_mean`` ~ LogNormal(log(alpha_hyper_hyper_mean),
        alpha_hyper_hyper_scale).
    alpha_hyper_scale : float
        Per-sample ``alpha`` ~ LogNormal(log(alpha_hyper_mean), alpha_hyper_scale).
    epsilon_hyper_hyper : float
        ``epsilon_hyper`` ~ Beta(1, 1 / epsilon_hyper_hyper).
    mu_hyper_mean, mu_hyper_scale : float
        ``mu`` ~ LogNormal(log(mu_hyper_mean), mu_hyper_scale).
    m_hyper_r : float
        ``total_count`` of the negative binomial on depth ``m``.
    dtype, device
        Used for all hyperparameter tensors.

    Sites and shapes: gamma (s, g); rho (s,); alpha_hyper_mean,
    epsilon_hyper (); pi (n, s); mu, epsilon, alpha (n,); m, p_noerr, p, y (n, g).
    Returns None.
    """
    
    # Convert all hyperparameters to tensors with the requested dtype/device.
    gamma_hyper, rho_hyper, pi_hyper, alpha_hyper_hyper_mean, alpha_hyper_hyper_scale, alpha_hyper_scale, epsilon_hyper_hyper, mu_hyper_mean, mu_hyper_scale, m_hyper_r = (
        as_torch(x, dtype=dtype, device=device)
        for x in [
            gamma_hyper,
            rho_hyper,
            pi_hyper,
            alpha_hyper_hyper_mean,
            alpha_hyper_hyper_scale,
            alpha_hyper_scale,
            epsilon_hyper_hyper,
            mu_hyper_mean,
            mu_hyper_scale,
            m_hyper_r,
        ]
    )

    # Genotypes
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            # Zero logits: prior probability 0.5 for every genotype entry.
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, logits=torch.zeros((1,), dtype=dtype, device=device).squeeze())
            )
    
    # Meta-community composition (uniform logits over the s strains).
    rho = pyro.sample('rho', dist.RelaxedOneHotCategorical(temperature=rho_hyper, logits=torch.zeros(s, dtype=dtype, device=device)))

    alpha_hyper_mean = pyro.sample('alpha_hyper_mean', dist.LogNormal(loc=torch.log(alpha_hyper_hyper_mean), scale=alpha_hyper_hyper_scale))
    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon_hyper_hyper))
    with pyro.plate('sample', n, dim=-1):
        # Community composition
        pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
        # Sample coverage
        mu = pyro.sample('mu', dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale))
        # Sequencing error; unsqueezed to (n, 1) so it broadcasts over positions.
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
        # Per-sample beta-binomial concentration, also (n, 1) after unsqueeze.
        alpha = pyro.sample('alpha', dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale)).unsqueeze(-1)
        
    # Sample/position coverage: per-sample mean mu, expanded over all g positions.
    m = pyro.sample('m', NegativeBinomialReparam(mu.reshape((-1, 1)), m_hyper_r).expand([n, g]).to_event())
    
    # Error model: a fraction epsilon/2 of probability flips between alleles.
    p_noerr = pyro.deterministic('p_noerr', pi @ gamma)
    p = pyro.deterministic(
        'p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )
    
    # Observation: beta-binomial with mean fraction p and concentration alpha.
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ).to_event(),
    )
    
    
def model_binomial_missing(
    n,
    g,
    s,
    gamma_hyper=1.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    rho_hyper=1.,
    pi_hyper=1.,
    epsilon_hyper_hyper=0.01,
    mu_hyper_mean=1.,
    mu_hyper_scale=1.,
    m_hyper_r=1.,
    dtype=torch.float32,
    device='cpu',
):
    """Binomial strain-mixture model with per-strain position presence/absence.

    Each strain/position has a relaxed presence indicator ``delta``. The
    expected depth fraction is nu = pi @ delta, and the allele frequency is
    averaged only over strains where the position is present.

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position', and 'strain' plates.
    gamma_hyper : float
        Temperature of the RelaxedBernoulli prior on genotypes ``gamma``.
    delta_hyper_temp, delta_hyper_p : float
        Temperature and probability of the RelaxedBernoulli prior on ``delta``.
    rho_hyper, pi_hyper : float
        Temperatures of the relaxed one-hot priors on ``rho`` and ``pi``.
    epsilon_hyper_hyper : float
        ``epsilon_hyper`` ~ Beta(1, 1 / epsilon_hyper_hyper).
    mu_hyper_mean, mu_hyper_scale : float
        ``mu`` ~ LogNormal(log(mu_hyper_mean), mu_hyper_scale).
    m_hyper_r : float
        ``total_count`` of the negative binomial on depth ``m``.
    dtype, device
        Used for all hyperparameter tensors.

    Sites and shapes: gamma, delta (s, g); rho (s,); epsilon_hyper ();
    pi (n, s); mu, epsilon (n,); nu, m, p_noerr, p, y (n, g). Returns None.
    """
    
    # Convert all hyperparameters to tensors with the requested dtype/device.
    (
        gamma_hyper,
        delta_hyper_temp,
        delta_hyper_p,
        rho_hyper,
        pi_hyper,
        epsilon_hyper_hyper,
        mu_hyper_mean,
        mu_hyper_scale,
        m_hyper_r,
    ) = (
        as_torch(x, dtype=dtype, device=device)
        for x in [
            gamma_hyper,
            delta_hyper_temp,
            delta_hyper_p,
            rho_hyper,
            pi_hyper,
            epsilon_hyper_hyper,
            mu_hyper_mean,
            mu_hyper_scale,
            m_hyper_r,
        ]
    )

    # Genotypes
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            # Zero logits: prior probability 0.5 for every genotype entry.
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, logits=torch.zeros((1,), dtype=dtype, device=device).squeeze())
            )
            # Position presence/absence
            delta = pyro.sample(
                'delta', dist.RelaxedBernoulli(temperature=delta_hyper_temp, probs=delta_hyper_p)
            )
    
    # Meta-community composition (uniform logits over the s strains).
    rho = pyro.sample('rho', dist.RelaxedOneHotCategorical(temperature=rho_hyper, logits=torch.zeros(s, dtype=dtype, device=device)))

    epsilon_hyper = pyro.sample('epsilon_hyper', dist.Beta(1., 1 / epsilon_hyper_hyper))
    with pyro.plate('sample', n, dim=-1):
        # Community composition
        pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
        # Sample coverage
        mu = pyro.sample('mu', dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale))
        # Sequencing error; unsqueezed to (n, 1) so it broadcasts over positions.
        epsilon = pyro.sample('epsilon', dist.Beta(1., 1 / epsilon_hyper)).unsqueeze(-1)
        
    # Depth at each position: nu (n, g) is the composition-weighted presence,
    # scaling each sample's mean coverage mu.
    nu = pyro.deterministic("nu", pi @ delta)
    m = pyro.sample('m', NegativeBinomialReparam(nu * mu.reshape((-1,1)), m_hyper_r).to_event())
  
    # Expected fractions of each allele at each position, averaged over the
    # strains where the position is present (normalized by nu).
    # NOTE(review): nu near 0 would make this division unstable.
    p_noerr = pyro.deterministic('p_noerr', pi @ (gamma * delta) / nu)
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )

    
    # Observation
    y = pyro.sample(
        'y',
        dist.Binomial(
            probs=p,
            total_count=m
        ).to_event(),
    )
    
def model_betabinomial_missing(
    n,
    g,
    s,
    gamma_hyper=1.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    rho_hyper=1.,
    pi_hyper=1.,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=1.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    mu_hyper_mean=1.,
    mu_hyper_scale=1.,
    m_hyper_r=1.,
    dtype=torch.float32,
    device='cpu',
):
    """Beta-binomial strain-mixture model with per-strain position presence/absence.

    Combines (i) a relaxed presence indicator ``delta`` for each
    strain/position with (ii) beta-binomial over-dispersion governed by a
    per-sample concentration ``alpha``.

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position', and 'strain' plates.
    gamma_hyper : float
        Temperature of the RelaxedBernoulli prior on genotypes ``gamma``.
    delta_hyper_temp, delta_hyper_p : float
        Temperature and probability of the RelaxedBernoulli prior on ``delta``.
    rho_hyper : float
        Temperature of the RelaxedOneHotCategorical prior on ``rho``.
    pi_hyper : float
        Temperature of the RelaxedOneHotCategorical prior on per-sample
        composition ``pi``, whose probabilities are ``rho``.
    alpha_hyper_hyper_mean, alpha_hyper_hyper_scale : float
        ``alpha_hyper_mean`` ~ LogNormal(log(alpha_hyper_hyper_mean),
        alpha_hyper_hyper_scale).
    alpha_hyper_scale : float
        Accepted but NOT used. The model samples its own
        ``alpha_hyper_scale`` site from LogNormal(0, 1), which shadows this
        argument.
    epsilon_hyper_alpha, epsilon_hyper_beta : float
        Per-sample error ``epsilon`` ~ Beta(epsilon_hyper_alpha, epsilon_hyper_beta).
    mu_hyper_mean, mu_hyper_scale : float
        ``mu`` ~ LogNormal(log(mu_hyper_mean), mu_hyper_scale).
    m_hyper_r : float
        ``total_count`` of the negative binomial on depth ``m``.
    dtype, device
        Used for every tensor created by the model.

    Sites and shapes: gamma, delta (s, g); rho (s,); alpha_hyper_mean,
    alpha_hyper_scale (); pi (n, s); mu, epsilon, alpha (n,);
    nu, m, p_noerr, p, y (n, g). Returns None.
    """
    # Convert every hyperparameter to a tensor with the requested dtype/device.
    (
        gamma_hyper,
        delta_hyper_temp,
        delta_hyper_p,
        rho_hyper,
        pi_hyper,
        alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale,
        alpha_hyper_scale,
        epsilon_hyper_alpha,
        epsilon_hyper_beta,
        mu_hyper_mean,
        mu_hyper_scale,
        m_hyper_r,
    ) = (
        as_torch(x, dtype=dtype, device=device)
        for x in [
            gamma_hyper,
            delta_hyper_temp,
            delta_hyper_p,
            rho_hyper,
            pi_hyper,
            alpha_hyper_hyper_mean,
            alpha_hyper_hyper_scale,
            alpha_hyper_scale,
            epsilon_hyper_alpha,
            epsilon_hyper_beta,
            mu_hyper_mean,
            mu_hyper_scale,
            m_hyper_r,
        ]
    )

    # Genotypes and per-strain presence/absence of each position.
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            # Zero logits: prior probability 0.5 for every genotype entry.
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, logits=torch.zeros((1,), dtype=dtype, device=device).squeeze())
            )
            # Position presence/absence
            delta = pyro.sample(
                'delta', dist.RelaxedBernoulli(temperature=delta_hyper_temp, probs=delta_hyper_p)
            )

    # Meta-community composition (uniform logits over the s strains).
    rho = pyro.sample('rho', dist.RelaxedOneHotCategorical(temperature=rho_hyper, logits=torch.zeros(s, dtype=dtype, device=device)))

    alpha_hyper_mean = pyro.sample('alpha_hyper_mean', dist.LogNormal(loc=torch.log(alpha_hyper_hyper_mean), scale=alpha_hyper_hyper_scale))
    # NOTE(review): this sample site shadows the ``alpha_hyper_scale`` argument,
    # so callers' value is never used. The scale is drawn from a LogNormal(0, 1)
    # hyperprior instead, or taken from conditioning data named
    # 'alpha_hyper_scale'. loc/scale are built with the model's dtype/device
    # rather than from Python ints.
    alpha_hyper_scale = pyro.sample(
        'alpha_hyper_scale',
        dist.LogNormal(
            loc=torch.zeros((), dtype=dtype, device=device),
            scale=torch.ones((), dtype=dtype, device=device),
        ),
    )
    with pyro.plate('sample', n, dim=-1):
        # Community composition
        pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
        # Sample coverage
        mu = pyro.sample('mu', dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale))
        # Sequencing error; unsqueezed to (n, 1) so it broadcasts over positions.
        epsilon = pyro.sample('epsilon', dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta)).unsqueeze(-1)
        # Per-sample beta-binomial concentration, also (n, 1) after unsqueeze.
        alpha = pyro.sample('alpha', dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale)).unsqueeze(-1)

    # Depth at each position: nu (n, g) is the composition-weighted presence,
    # scaling each sample's mean coverage mu.
    nu = pyro.deterministic("nu", pi @ delta)
    m = pyro.sample('m', NegativeBinomialReparam(nu * mu.reshape((-1,1)), m_hyper_r).to_event())

    # Expected fractions of each allele at each position, averaged over the
    # strains where the position is present (normalized by nu).
    # NOTE(review): nu near 0 would make this division unstable.
    p_noerr = pyro.deterministic('p_noerr', pi @ (gamma * delta) / nu)
    # Error model: a fraction epsilon/2 of probability flips between alleles.
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )

    # Observation: beta-binomial with mean fraction p and concentration alpha.
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ).to_event(),
    )

In [None]:
# Visualize 10000 draws from the Beta(epsilon_hyper_alpha, epsilon_hyper_beta)
# prior used for the per-sample sequencing error in model_betabinomial_missing.
epsilon_hyper_alpha, epsilon_hyper_beta = 1.5, 1.5 / 0.01
plt.hist(pyro.sample('epsilon_hyper', dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta).expand([10000])).cpu().numpy(), bins=100)
# A trailing None keeps the notebook from echoing plt.hist's return value.
None

In [None]:
# Log the per-site shapes of one trace of model_binomial_missing.
shape_info(model_binomial_missing, n=100, g=200, s=20)

## Simulation

In [None]:
# simulate.py

def condition_model(model, data=None, device='cpu', dtype=torch.float32, **model_kwargs):
    """Condition ``model`` on observed ``data`` and pre-bind its keyword arguments.

    ``data`` maps sample-site names to values. Each value is converted to a
    tensor with ``dtype``/``device`` (no observations when ``data`` is None).
    Returns a ``functools.partial`` of the conditioned model with
    ``model_kwargs``, ``dtype``, and ``device`` bound as keywords.
    """
    observed = all_torch(**(data if data is not None else {}), dtype=dtype, device=device)
    conditioned = pyro.condition(model, data=observed)
    return partial(conditioned, **model_kwargs, dtype=dtype, device=device)
    
def simulate(model):
    """Draw one forward sample of every site in ``model``.

    Returns a dict mapping site name to a NumPy array with all singleton
    dimensions squeezed out, including the leading num_samples=1 axis.
    """
    draws = pyro.infer.Predictive(model, num_samples=1)()
    return {name: value.detach().cpu().numpy().squeeze() for name, value in draws.items()}

### SimShape-1: Small study

In [None]:
# Fix the RNG seed so the simulated dataset is reproducible.
seed = 1
pyro.util.set_rng_seed(seed)

# Simulated dimensions: samples, positions, strains.
n_sim = 100
g_sim = 5000
s_sim = 20

# Draw one dataset from model_betabinomial_missing, with the site
# alpha_hyper_mean conditioned to 100.
sim1 = simulate(
    condition_model(
        model_betabinomial_missing,
        data=dict(
            alpha_hyper_mean=100.
        ),
        n=n_sim,
        g=g_sim,
        s=s_sim,
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_p=0.7,
        pi_hyper=0.5,
        rho_hyper=2.,
        mu_hyper_mean=2.,
        mu_hyper_scale=0.5,
        m_hyper_r=10.,
        # NOTE(review): model_betabinomial_missing overwrites this argument with
        # its own sampled 'alpha_hyper_scale' site, so this value has no effect.
        alpha_hyper_scale=0.5,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5/0.01,
        device='cpu'
    )
)

## Visualization

In [None]:
# genotype.py

from scipy.spatial.distance import pdist, squareform
import scipy as sp

def prob_to_sign(gamma):
    """Map probabilities in [0, 1] linearly onto [-1, 1] (0 -> -1, 0.5 -> 0, 1 -> 1)."""
    signed = gamma * 2
    return signed - 1

# TODO: Demonstrate that genotype_distance (below) degrades gracefully under
# decreasing coverage and increasing missingness.
def genotype_distance(x, y):
    """Weighted mean squared half-difference between two probability genotypes.

    ``x`` and ``y`` hold per-position probabilities in [0, 1]; both are first
    mapped to [-1, 1] (p * 2 - 1). Each position's squared half-difference is
    weighted by the squared product of the signed values, so positions where
    either genotype is near 0.5 (uncertain) contribute little.
    """
    xs = x * 2 - 1
    ys = y * 2 - 1
    weight = (xs * ys) ** 2
    sq_half_diff = ((xs - ys) / 2) ** 2
    return (weight * sq_half_diff).mean() / weight.mean()

# TODO: Demonstrate that sign_genotype_distance degrades gracefully under
# decreasing coverage and increasing missingness.
def sign_genotype_distance(x, y):
    """Same weighted metric as ``genotype_distance``, for inputs already in [-1, 1].

    Computes mean(w * ((x - y) / 2) ** 2) / mean(w) with w = (x * y) ** 2.
    """
    weight = (x * y) ** 2
    numerator = (weight * ((x - y) / 2) ** 2).mean()
    return numerator / weight.mean()

def genotype_pdist(gamma):
    """Condensed pairwise ``genotype_distance`` values between the rows of ``gamma``.

    Thin wrapper around ``scipy.spatial.distance.pdist`` with a Python
    callable metric, so the metric is called once per pair.
    """
    condensed = pdist(gamma, metric=genotype_distance)
    return condensed

def genotype_pdist2(gamma, progress=False):
    """Condensed pairwise genotype distances between the rows of ``gamma``.

    Rows (probabilities in [0, 1]) are mapped to [-1, 1] via ``gamma * 2 - 1``
    and compared with the same weighted metric as ``sign_genotype_distance``:
    mean(w * ((x - y) / 2) ** 2) / mean(w), with w = (x * y) ** 2.
    Output order matches ``scipy.spatial.distance.pdist``: (0, 1), (0, 2), ...,
    (0, m-1), (1, 2), ..., (m-2, m-1).

    Parameters
    ----------
    gamma : array-like (ndarray or DataFrame), shape (m, g)
    progress : bool
        Show a tqdm progress bar that counts compared pairs.

    Returns
    -------
    numpy.ndarray, shape (m * (m - 1) // 2,)
    """
    X = np.asarray(gamma * 2 - 1)
    m = X.shape[0]
    dm = np.empty((m * (m - 1)) // 2)
    k = 0
    with tqdm(total=len(dm), disable=(not progress)) as pbar:
        for i in range(m - 1):
            # Compare row i with every later row in one vectorized step,
            # instead of one Python-level metric call per pair.
            others = X[i + 1:]
            sq_half_diff = ((X[i] - others) / 2) ** 2
            weight = (X[i] * others) ** 2
            row_dists = (weight * sq_half_diff).mean(axis=1) / weight.mean(axis=1)
            dm[k:k + len(row_dists)] = row_dists
            k += len(row_dists)
            pbar.update(len(row_dists))
    return dm
    

# TODO: Try out cosine instead of 'genotype_distance'
# def genotype_pdist(gamma):
#     return pdist(gamma * 2 - 1, metric='cosine')

def counts_to_p_estimate(y, m, pseudo=1):
    """Pseudocount-smoothed frequency estimate: (y + pseudo) / (m + 2 * pseudo).

    With zero counts (y = m = 0) this returns 0.5.
    """
    numerator = y + pseudo
    denominator = m + pseudo * 2
    return numerator / denominator

def genotype_linkage(gamma, progress=False, **kwargs):
    """Hierarchically cluster the rows of ``gamma`` by genotype distance.

    Distances come from ``genotype_pdist2``. ``kwargs`` are forwarded to
    ``scipy.cluster.hierarchy.linkage`` and default to ``method='complete'``.

    Returns
    -------
    (linkage_matrix, condensed_distances)
    """
    condensed = genotype_pdist2(gamma, progress=progress)
    options = {'method': 'complete', **kwargs}
    tree = sp.cluster.hierarchy.linkage(condensed, **options)
    return tree, condensed

In [None]:
# plot.py

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl

def plot_genotype(gamma, linkage_kw=None, **kwargs):
    """Clustered heatmap of genotypes, with rows of ``gamma`` shown as columns.

    Columns are ordered by ``genotype_linkage(gamma, **linkage_kw)``. Values are
    mapped to [-1, 1] with ``prob_to_sign`` and drawn with a diverging
    colormap. Figure size scales with the matrix shape. Extra ``kwargs``
    override the ``sns.clustermap`` defaults.
    """
    linkage_options = {} if linkage_kw is None else linkage_kw
    col_linkage, _ = genotype_linkage(gamma, **linkage_options)

    matrix = gamma.T
    n_rows, n_cols = matrix.shape
    dend_w, dend_h = 0.2, 1.0
    fig_w = n_cols * 0.15 + dend_w
    fig_h = n_rows * 0.02 + dend_h

    options = {
        'vmin': -1,
        'vmax': 1,
        'cmap': 'coolwarm',
        'dendrogram_ratio': (dend_w / fig_w, dend_h / fig_h),
        'col_linkage': col_linkage,
        'figsize': (fig_w, fig_h),
        'xticklabels': 1,
        'yticklabels': 0,
    }
    options.update(kwargs)
    sns.clustermap(prob_to_sign(matrix), **options)
    
def plot_missing(delta, **kwargs):
    """Clustered heatmap of a [0, 1]-valued matrix such as presence ``delta``.

    ``delta`` is transposed so its rows become heatmap columns. The color
    scale is fixed to [0, 1]. Extra ``kwargs`` override the
    ``sns.clustermap`` defaults.
    """
    matrix = delta.T
    n_rows, n_cols = matrix.shape
    dend_w, dend_h = 0.2, 1.0
    fig_w = n_cols * 0.15 + dend_w
    fig_h = n_rows * 0.02 + dend_h

    options = dict(
        vmin=0,
        vmax=1,
        dendrogram_ratio=(dend_w / fig_w, dend_h / fig_h),
        figsize=(fig_w, fig_h),
        xticklabels=1,
        yticklabels=0,
    )
    options.update(kwargs)
    sns.clustermap(matrix, **options)
    
def plot_community(pi, **kwargs):
    """Clustered heatmap of a composition matrix ``pi``, drawn untransposed.

    Rows and columns are clustered with cosine distance, and colors span
    [0, 1]. Extra ``kwargs`` override the ``sns.clustermap`` defaults.
    """
    n_rows, n_cols = pi.shape
    dend_w, dend_h = 0.2, 1.0
    fig_w = n_cols * 0.2 + dend_w
    fig_h = n_rows * 0.1 + dend_h

    options = dict(
        metric='cosine',
        vmin=0,
        vmax=1,
        dendrogram_ratio=(dend_w / fig_w, dend_h / fig_h),
        figsize=(fig_w, fig_h),
        xticklabels=1,
    )
    options.update(kwargs)
    sns.clustermap(pi, **options)
    
def plot_genotype_similarity(gamma, linkage_kw=None, **kwargs):
    """Clustered heatmap of pairwise genotype similarity (1 - distance) between rows of ``gamma``.

    Rows and columns share the linkage from ``genotype_linkage``. Extra
    ``kwargs`` override the ``sns.clustermap`` defaults.
    """
    linkage_options = {} if linkage_kw is None else linkage_kw
    linkage, condensed = genotype_linkage(gamma, **linkage_options)
    similarity = 1 - squareform(condensed)

    size = gamma.shape[0]
    margin = 0.5
    fig_w = size * 0.15 + margin
    fig_h = size * 0.15 + margin

    options = dict(
        vmin=0,
        vmax=1,
        dendrogram_ratio=(margin / fig_w, margin / fig_h),
        row_linkage=linkage,
        col_linkage=linkage,
        figsize=(fig_w, fig_h),
        xticklabels=1,
        yticklabels=1,
    )
    options.update(kwargs)
    sns.clustermap(similarity, **options)

In [None]:
# Subset sizes (samples, positions, strains) used by the plots below.
n_plt = 100
g_plt = 200
s_plt = 20

In [None]:
# Genotypes estimated directly from simulated counts ((y + 1) / (m + 2)), per sample.
plot_genotype(counts_to_p_estimate(sim1['y'][:n_plt, :g_plt], sim1['m'][:n_plt, :g_plt]), linkage_kw=dict(progress=True))

In [None]:
# pi has shape (samples, strains): slice samples first, then strains.
plot_community(sim1['pi'][:n_plt, :s_plt])

In [None]:
# Similarity between samples' count-based genotype estimates.
plot_genotype_similarity(counts_to_p_estimate(sim1['y'][:n_plt, :g_plt], sim1['m'][:n_plt, :g_plt]), linkage_kw=dict(progress=True))

In [None]:
# True simulated strain genotypes.
plot_genotype(sim1['gamma'][:s_plt, :g_plt])

In [None]:
# True strain/position presence (delta).
plot_missing(sim1['delta'][:s_plt, :g_plt])

In [None]:
# Composition-weighted presence per sample/position (nu = pi @ delta).
plot_missing(sim1['nu'][:n_plt, :g_plt])

In [None]:
# Simulated depths, on a symmetric-log color scale.
sns.clustermap(sim1['m'][:n_plt, :g_plt], norm=mpl.colors.SymLogNorm(linthresh=1))

In [None]:
# Similarity between true strain genotypes.
plot_genotype_similarity(sim1['gamma'][:s_plt, :g_plt])

In [None]:
# Distribution of per-sample sequencing error rates.
plt.hist(sim1['epsilon'], bins=50)
None

In [None]:
# Distribution of per-sample beta-binomial concentrations.
plt.hist(sim1['alpha'], bins=20)
None

## Estimation

### Initialization

In [None]:
from sklearn.cluster import AgglomerativeClustering

def cluster_genotypes(
    gamma, thresh, progress=False
):
    """Complete-linkage clustering of the rows of ``gamma``, cut at distance ``thresh``.

    Pairwise distances come from ``genotype_pdist2``.

    Parameters
    ----------
    gamma : array-like, shape (m, g)
        Genotype probabilities, one row per item to cluster.
    thresh : float
        Linkage distance above which clusters are not merged.
    progress : bool
        Show a progress bar while computing distances.

    Returns
    -------
    pandas.Series
        Integer cluster label for each row of ``gamma``.
    """
    dmat = squareform(genotype_pdist2(gamma, progress=progress))
    clustering_kw = dict(
        n_clusters=None,
        linkage="complete",
        distance_threshold=thresh,
    )
    try:
        # scikit-learn >= 1.2 calls this argument ``metric``. ``affinity`` was
        # deprecated in 1.2 and removed in 1.4.
        model = AgglomerativeClustering(metric="precomputed", **clustering_kw)
    except TypeError:
        # Older scikit-learn accepts only ``affinity``.
        model = AgglomerativeClustering(affinity="precomputed", **clustering_kw)
    return pd.Series(model.fit(dmat).labels_)

def initialize_parameters_by_clustering_samples(
    y, m, thresh, additional_strains_factor=0.5, progress=False,
):
    """Build initial genotype and composition estimates by clustering samples.

    1. Cluster samples on smoothed per-position frequencies (y + 1) / (m + 2).
    2. Pool counts within each cluster into one genotype per cluster.
    3. Append ``int(additional_strains_factor * n_clusters)`` uninformative
       genotypes (all 0.5).
    4. Give each sample weight ``s_init - 1`` on its own cluster and 1 on every
       other strain, then normalize each row to sum to 1.

    Parameters
    ----------
    y, m : ndarray, shape (n, g)
        Per-sample, per-position allele counts and total counts.
    thresh : float
        Distance threshold passed to ``cluster_genotypes``.
    additional_strains_factor : float
        Extra uninformative strains, as a fraction of the cluster count.
    progress : bool
        Show a progress bar while clustering.

    Returns
    -------
    gamma_init : ndarray, shape (s_init, g)
    pi_init : ndarray, shape (n, s_init)

    Raises
    ------
    ValueError
        If the initial genotypes contain NaN.
    """
    n, g = y.shape

    sample_genotype = (y + 1) / (m + 2)
    clust = cluster_genotypes(sample_genotype, thresh=thresh, progress=progress)

    # Pool counts within each cluster. Groups are sorted by cluster label, so
    # row k corresponds to cluster label k.
    y_total = pd.DataFrame(y).groupby(clust).sum().values
    m_total = pd.DataFrame(m).groupby(clust).sum().values
    clust_genotype = (y_total + 1) / (m_total + 2)
    additional_haplotypes = int(
        additional_strains_factor * clust_genotype.shape[0]
    )

    gamma_init = np.vstack([
        clust_genotype,
        np.full((additional_haplotypes, g), 0.5),
    ])

    s_init = gamma_init.shape[0]
    pi_init = np.ones((n, s_init))
    # Sample i puts weight s_init - 1 on its own cluster (vectorized over samples).
    pi_init[np.arange(n), clust.to_numpy()] = s_init - 1
    pi_init /= pi_init.sum(1, keepdims=True)

    if np.isnan(gamma_init).any():
        raise ValueError("NaN in initial genotype estimates (gamma_init).")

    return gamma_init, pi_init

In [None]:
# Fit on a subset of positions (all samples) to keep estimation fast.
g_fit = 1000  # sim1['y'].shape[1]
n_fit = sim1['y'].shape[0]

# Cluster-based initial values; additional_strains_factor=0. adds no extra strains.
sim1_gamma_init, sim1_pi_init = initialize_parameters_by_clustering_samples(
    sim1['y'][:n_fit, :g_fit],
    sim1['m'][:n_fit, :g_fit],
    thresh=0.1,
    additional_strains_factor=0.,
    progress=True,
)

print(sim1_pi_init.shape)

In [None]:
# Initial genotypes.
plot_genotype(sim1_gamma_init[:s_plt, :g_plt])

In [None]:
# Similarity between initial genotypes.
plot_genotype_similarity(sim1_gamma_init)

In [None]:
# Initial compositions (samples x strains).
plot_community(sim1_pi_init[:n_plt, :s_plt])

### Fitting

In [None]:
# estimation.py

def estimate_parameters(
    model,
    data,
    dtype=torch.float32,
    device='cpu',
    initialize_params=None,
    maxiter=10000,
    lag=100,
    lr=1e-0,
    clip_norm=100,
    progress=True,
    **model_kwargs,
):
    """Fit ``model`` to ``data`` with SVI and an AutoLaplaceApproximation guide.

    Parameters
    ----------
    model : callable
        Pyro model; ``model_kwargs`` plus ``dtype``/``device`` are bound to it.
    data : dict
        Observed sample-site values to condition on.
    initialize_params : dict, optional
        Initial values for named latent sites (via ``init_to_value``).
    maxiter : int
        Maximum number of SVI steps.
    lag : int
        Window for the convergence check. Every 10 steps (once more than
        ``lag`` steps have run), fitting stops when the average per-step loss
        decrease over the last ``lag`` steps is <= 0.
    lr, clip_norm : float
        Adamax learning rate and gradient clipping norm.
    progress : bool
        Show a progress bar and log convergence.

    Returns
    -------
    est : dict
        Site name -> NumPy array, from one Predictive draw using the fitted
        guide (mean over the sample axis, then squeezed).
    history : list of float
        Loss value returned by ``svi.step()`` at each iteration.

    Raises
    ------
    RuntimeError
        If the loss becomes NaN. A KeyboardInterrupt stops fitting early and
        still returns the current estimate.
    """
    conditioned_model = condition_model(
        model,
        data=data,
        dtype=dtype,
        device=device,
        **model_kwargs,
    )
    if initialize_params is None:
        initialize_params = {}

    _guide = pyro.infer.autoguide.AutoLaplaceApproximation(
        conditioned_model,
        init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
            values=all_torch(**initialize_params, dtype=dtype, device=device)
        ),
    )
    opt = pyro.optim.Adamax({"lr": lr}, {"clip_norm": clip_norm})
    svi = pyro.infer.SVI(
        conditioned_model,
        _guide,
        opt,
        loss=pyro.infer.JitTrace_ELBO()
    )
    pyro.clear_param_store()

    history = []
    pbar = tqdm(range(maxiter), disable=(not progress))
    try:
        for i in pbar:
            elbo = svi.step()

            if np.isnan(elbo):
                raise RuntimeError("ELBO NaN?")

            # Fit tracking
            history.append(elbo)

            # Reporting/Breaking, checked every 10 steps once past the lag window.
            if (i % 10 == 0) and (i > lag):
                delta = history[-2] - history[-1]
                # Average per-step decrease in loss over the last `lag` steps.
                delta_lag = (history[-lag] - history[-1]) / lag
                if delta_lag <= 0:
                    if progress:
                        info("Converged")
                    break
                pbar.set_postfix({
                    'ELBO': history[-1],
                    'delta': delta,
                    f'lag{lag}': delta_lag,
                })
    except KeyboardInterrupt:
        # Manual interruption is a supported way to stop early: fall through
        # and report the estimate at the current parameters.
        info("Interrupted")
    finally:
        # A `break` leaves the tqdm iterator unfinished, so close it explicitly.
        pbar.close()

    # Computed outside `finally` so a failure here cannot mask an exception
    # raised during fitting (e.g. the NaN RuntimeError above).
    est = pyro.infer.Predictive(conditioned_model, guide=_guide, num_samples=1)()
    est = {
        k: v.detach().cpu().numpy().mean(0).squeeze()
        for k, v in est.items()
    }
    return est, history

In [None]:
# Fit model_betabinomial_missing to the first g_fit positions of sim1,
# starting from the clustering-based initial genotypes and compositions.
s_fit = sim1_gamma_init.shape[0]
initialize_params = dict(gamma=sim1_gamma_init, pi=sim1_pi_init)

sim1_fit1, history = estimate_parameters(
    model_betabinomial_missing,
    data=dict(y=sim1['y'][:, :g_fit], m=sim1['m'][:, :g_fit]),
    n=n_fit,
    g=g_fit,
    s=s_fit,
    gamma_hyper=0.1,
    pi_hyper=1.0,
    rho_hyper=0.5,
    mu_hyper_mean=5,
    mu_hyper_scale=5.,
    m_hyper_r=10.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=10.,
    # NOTE(review): ignored by model_betabinomial_missing, which samples its own
    # 'alpha_hyper_scale' site.
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    initialize_params=initialize_params,
    device='cpu',
    lag=100,
    lr=1e-1,
)

In [None]:
def plot_loss(trace):
    """Plot a loss trace relative to its minimum, on a log y-axis.

    The minimum becomes exactly 0, so that point does not appear on the
    log scale.
    """
    values = np.array(trace)
    plt.plot(values - values.min())
    plt.yscale('log')
    
plot_loss(history)

### Merging Strains

In [None]:
def merge_similar_genotypes(
    gamma, pi, thresh, delta=None,
):
    """Merge strains whose (delta-masked) genotypes cluster within ``thresh``.

    Strains are clustered by ``cluster_genotypes`` on ``gamma * delta``.
    Merged genotypes are averaged in logit space, and merged abundances are
    summed.

    Parameters
    ----------
    gamma : array-like, shape (s, g)
        Strain genotype probabilities.
    pi : array-like, shape (n, s)
        Per-sample strain compositions.
    thresh : float
        Clustering distance threshold.
    delta : array-like, same shape as ``gamma``, optional
        Presence weights used to mask genotypes before clustering. Defaults
        to all ones.

    Returns
    -------
    gamma_mean : ndarray, shape (n_clusters, g)
    pi_sum : ndarray, shape (n, n_clusters)
        In both outputs, clusters are ordered by cluster label.
    """
    if delta is None:
        delta = np.ones_like(gamma)

    clust = cluster_genotypes(gamma * delta, thresh=thresh)
    # Average in logit space, then map back to probabilities. The previous
    # expit(logit(x)).mean(0) reduced to a plain arithmetic mean, because
    # expit undoes logit.
    gamma_mean = (
        pd.DataFrame(gamma)
        .groupby(clust)
        .apply(lambda x: sp.special.expit(sp.special.logit(x).mean(0)))
        .values
    )
    # Sum the abundance columns of merged strains. Transpose rather than use
    # groupby(..., axis='columns'), which is deprecated in pandas >= 2.1.
    pi_sum = pd.DataFrame(pi).T.groupby(clust).sum().T.values

    return gamma_mean, pi_sum

# Merge fitted strains whose genotypes cluster within distance 0.1, then
# compare strain counts before and after merging.
sim1_fit1_gamma_merge, sim1_fit1_pi_merge = merge_similar_genotypes(
    sim1_fit1['gamma'],
    sim1_fit1['pi'],
    thresh=0.1,
)

# print(sim1_gamma_init.shape[0], sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])
print(sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])

## Evaluation

In [None]:
def mask_missing_genotype(gamma, delta):
    """Shrink genotype probabilities toward 0.5 where positions are absent.

    Multiplies the logit of ``gamma`` by ``delta``: delta == 1 leaves gamma
    unchanged, and delta == 0 maps it to 0.5.
    """
    scaled_logit = sp.special.logit(gamma) * delta
    return sp.special.expit(scaled_logit)

In [None]:
# Genotypes with absent positions pulled toward 0.5, for the truth and the
# fit, restricted to the fitted positions.
sim1_gamma_adjusted = mask_missing_genotype(sim1['gamma'][:, :g_fit], sim1['delta'][:, :g_fit])
sim1_fit1_gamma_adjusted = mask_missing_genotype(sim1_fit1['gamma'], sim1_fit1['delta'])

### Ground Truth

#### Visualization

In [None]:
def plot_genotype_comparison(data=None, **kwargs):
    """Plot several genotype matrices together in one ``plot_genotype`` heatmap.

    ``data`` maps a label to a (strains, positions) array. Rows are stacked,
    and each row is labeled ``f'{label}_{i}'``. Extra ``kwargs`` go to
    ``plot_genotype`` (``xticklabels`` defaults to 1).

    Raises
    ------
    ValueError
        If ``data`` is None or empty, since there is nothing to plot.
    """
    if not data:
        raise ValueError("plot_genotype_comparison requires a non-empty `data` mapping.")
    stacked = pd.concat([
        pd.DataFrame(data[k], index=[f'{k}_{i}' for i in range(data[k].shape[0])])
        for k in data
    ])
    kw = dict(xticklabels=1)
    kw.update(kwargs)
    plot_genotype(stacked, **kw)

def plot_community_comparison(data=None, **kwargs):
    """Plot several composition matrices side by side in one ``plot_community`` heatmap.

    ``data`` maps a label to a (samples, strains) array. Columns are
    concatenated, and each column is labeled ``f'{label}_{i}'``. Extra
    ``kwargs`` go to ``plot_community`` (``xticklabels`` defaults to 1).

    Raises
    ------
    ValueError
        If ``data`` is None or empty, since there is nothing to plot.
    """
    if not data:
        raise ValueError("plot_community_comparison requires a non-empty `data` mapping.")
    stacked = pd.concat([
        pd.DataFrame(data[k], columns=[f'{k}_{i}' for i in range(data[k].shape[1])])
        for k in data
    ], axis=1)
    kw = dict(xticklabels=1)
    kw.update(kwargs)
    plot_community(stacked, **kw)


In [None]:
# True vs. fitted genotypes, both masked by their presence estimates.
plot_genotype_comparison(
    data=dict(
        true=sim1_gamma_adjusted[:, :g_plt],
#         fit=sim1_fit1['gamma'][:, :g_plt],
        adj=sim1_fit1_gamma_adjusted[:, :g_plt],
#         init=sim1_gamma_init,
#         merg=sim1_fit1_gamma_merge,
    ),
    linkage_kw=dict(progress=True),
)

In [None]:
# True vs. fitted sample compositions, side by side.
plot_community_comparison(
    data=dict(
        true=sim1['pi'],
        fit=sim1_fit1['pi'],
#         init=sim1_pi_init,
#         merg=sim1_fit1_pi_merge,
    ),
)

In [None]:
# True vs. fitted per-sample error rates, with the identity line.
plt.scatter(sim1['epsilon'], sim1_fit1['epsilon'])
plt.plot([0, 0.04], [0, 0.04])

In [None]:
# True vs. fitted per-sample concentrations, with the identity line.
plt.scatter(sim1['alpha'], sim1_fit1['alpha'])
plt.plot([0, 200], [0, 200])

In [None]:
# Fitted strain/position presence.
sns.heatmap(sim1_fit1['delta'], vmin=0, vmax=1)

In [None]:
# True vs. fitted per-sample coverage, with the identity line.
plt.scatter(sim1['mu'], sim1_fit1['mu'])
plt.plot([0, 40], [0, 40])

In [None]:
# TODO: Plot comparing genotype accuracy to true strain abundance
# colored by mean entropy of the estimated genotype masked by delta

#### Fit scores

In [None]:
def binary_entropy(p):
    """Entropy in bits of a Bernoulli(p) variable, computed elementwise.

    Uses ``scipy.special.xlogy`` so that p == 0 and p == 1 give the limiting
    value 0. The naive 0 * log2(0) gives NaN there.
    """
    from scipy.special import xlogy

    q = 1 - p
    return -(xlogy(p, p) + xlogy(q, q)) / np.log(2)

def sum_binary_entropy(p, normalize=False, axis=None):
    """Sum of elementwise binary entropies (in bits) of ``p`` along ``axis``.

    Entries equal to 0 or 1 contribute 0 instead of NaN. With
    ``normalize=True``, the sum is divided by the number of summed entries,
    giving the mean entropy. When ``axis`` is None that count is ``p.size``;
    the old ``p.shape[None]`` raised there.
    """
    from scipy.special import xlogy

    p = np.asarray(p)
    q = 1 - p
    ent = np.sum(-(xlogy(p, p) + xlogy(q, q)) / np.log(2), axis=axis)
    if normalize:
        n_terms = p.size if axis is None else p.shape[axis]
        ent = ent / n_terms
    return ent

def mean_masked_genotype_entropy(gamma, delta):
    """Per-row mean of ``binary_entropy(gamma) * delta``.

    Expects 2-D inputs (the mean is over axis 1). Divides by the full row
    length, not by the sum of ``delta``.
    """
    masked = binary_entropy(gamma) * delta
    return masked.mean(axis=1)

In [None]:
# Quality of genotypes.
# For each true genotype, compare its genotype to the
# best inferred genotype.
# This is our score for the quality of the genotype inferences.

from scipy.spatial.distance import cdist
import pandas as pd

def best_genotype_hits(gammaA, gammaB):
    """For each row of ``gammaA``, find the closest row of ``gammaB``.

    Distance is the cityblock (L1) distance divided by the number of
    positions, i.e. the mean absolute per-position difference.

    Returns
    -------
    (pandas.Series, pandas.Series)
        For each gammaA row: the index of the best-matching gammaB row, and
        that normalized distance.
    """
    n_positions = gammaA.shape[1]
    pairwise = pd.DataFrame(cdist(gammaA, gammaB, metric='cityblock'))
    best_index = pairwise.idxmin(axis=1)
    best_distance = pairwise.min(axis=1) / n_positions
    return best_index, best_distance

# For each true strain, find the closest fitted strain (both masked by presence).
best_hit, best_dist = best_genotype_hits(sim1_gamma_adjusted[:, :g_fit], sim1_fit1_gamma_adjusted[:, :g_fit])

# Genotype error of each true strain, weighted by its mean abundance across samples.
print('weighted_mean_distance:', (best_dist * sim1['pi'].mean(0)).sum())
# Error vs. each true strain's total expected coverage (sum over samples of pi * mu).
plt.scatter((sim1['pi'] * sim1['mu'].reshape(-1, 1)).sum(0), best_dist)

In [None]:
# Quality of abundance estimates
# Compare BC distance matrices for inferences to the true distance matrix

# Pairwise Bray-Curtis similarity (1 - distance) between sample compositions.
bc_sim = 1 - pdist(sim1['pi'], metric='braycurtis')
bc_fit = 1 - pdist(sim1_fit1['pi'], metric='braycurtis')

plt.scatter(
    bc_sim,
    bc_fit,
    marker='.',
    alpha=0.2,
)

# Mean absolute error of the fitted similarities.
print(np.abs((bc_sim - bc_fit)).mean())

### No Ground Truth

#### Visualization

In [None]:
# Strains that are not representative of true haplotypes
# are high entropy (even after masking with delta)
# and have low estimated total coverage.

# For each fitted strain, distance to the closest true strain.
best_true_strain, best_true_strain_dist = best_genotype_hits(sim1_fit1_gamma_adjusted[:, :g_fit], sim1_gamma_adjusted[:, :g_fit])
# NOTE: a bare expression mid-cell is not displayed (only a cell's last expression is).
best_true_strain_dist

# Distance vs. estimated total coverage, colored by mean masked genotype entropy.
plt.scatter((sim1_fit1['pi'] * sim1_fit1['mu'].reshape((-1, 1))).sum(0), best_true_strain_dist, c=mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']))

In [None]:
# Fitted genotypes, masked by fitted presence.
plot_genotype(sim1_fit1_gamma_adjusted[:, :g_plt], linkage_kw=dict(progress=True))

In [None]:
# Fitted sample compositions.
plot_community(sim1_fit1['pi'])

#### Confidence Scores

In [None]:
# Distribution of per-strain mean masked genotype entropy (lower = more confident).
plt.hist(mean_masked_genotype_entropy(sim1_fit1['gamma'][:, :g_plt], sim1_fit1['delta'][:, :g_plt]))
None

In [None]:
# Distribution of fitted per-sample concentrations.
plt.hist(sim1_fit1['alpha'], bins=20)
None

In [None]:
# Only fitted strains whose mean masked entropy is below 0.1.
plot_genotype(sim1_fit1['gamma'][mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']) < 0.1, :g_fit])

## Experiment: Average and variation in accuracy

In [None]:
def simulate_then_run_workflow(n, g, s, sim_kwargs, init_kwargs, fit_kwargs, seed=1):
    """Simulate one dataset, fit it, and score the fit against the ground truth.

    Steps:
    1. Seed the RNG.
    2. Simulate from ``model_betabinomial_missing`` (``sim_kwargs``).
    3. Initialize by clustering samples (``init_kwargs``).
    4. Fit with ``estimate_parameters`` (``fit_kwargs``).
    5. Compare the fit with the simulation.

    Returns
    -------
    weighted_mean_genotype_error : float
        Normalized distance from each true strain to its best fitted strain
        (both masked by presence), weighted by the true strain's mean
        abundance across samples.
    mean_beta_diversity_error : float
        Mean absolute difference between true and fitted pairwise
        Bray-Curtis similarities of sample compositions.
    strain_count_error : int
        Number of fitted strains minus ``s``.
    """
    pyro.util.set_rng_seed(seed)
    truth = simulate(
        condition_model(model_betabinomial_missing, n=n, g=g, s=s, **sim_kwargs)
    )

    gamma_init, pi_init = initialize_parameters_by_clustering_samples(
        truth['y'], truth['m'], **init_kwargs
    )
    n_strains_fit = gamma_init.shape[0]
    fit, _history = estimate_parameters(
        model_betabinomial_missing,
        data=dict(y=truth['y'], m=truth['m']),
        n=n,
        g=g,
        s=n_strains_fit,
        initialize_params=dict(gamma=gamma_init, pi=pi_init),
        **fit_kwargs,
    )

    true_masked = mask_missing_genotype(truth['gamma'], truth['delta'])
    fit_masked = mask_missing_genotype(fit['gamma'], fit['delta'])
    _, genotype_dist = best_genotype_hits(true_masked, fit_masked)
    weighted_mean_genotype_error = (genotype_dist * truth['pi'].mean(0)).sum()

    similarity_true = 1 - pdist(truth['pi'], metric='braycurtis')
    similarity_fit = 1 - pdist(fit['pi'], metric='braycurtis')
    mean_beta_diversity_error = np.abs(similarity_true - similarity_fit).mean()

    return weighted_mean_genotype_error, mean_beta_diversity_error, n_strains_fit - s

In [None]:
# Repeat simulate -> initialize -> fit -> score for several seeds. Each line
# printed: seed, weighted genotype error, mean Bray-Curtis similarity error,
# strain-count error.
# NOTE(review): alpha_hyper_scale in sim_kwargs/fit_kwargs is ignored by
# model_betabinomial_missing, which samples its own 'alpha_hyper_scale' site.
replicates = 5
seed_start = 0
for seed in range(seed_start, seed_start + replicates):
    generr, comperr, scounterr = simulate_then_run_workflow(
        100,
        200,
        20,
        sim_kwargs=dict(
            data=dict(
                alpha_hyper_mean=100.
            ),
            gamma_hyper=0.01,
            delta_hyper_temp=0.01,
            delta_hyper_p=0.7,
            pi_hyper=0.5,
            rho_hyper=2.,
            mu_hyper_mean=1.,
            mu_hyper_scale=0.5,
            m_hyper_r=10.,
            alpha_hyper_scale=0.5,
            epsilon_hyper_alpha=1.5,
            epsilon_hyper_beta=1.5/0.01,
            device='cpu'
        ),
        init_kwargs=dict(
            thresh=0.1,
            additional_strains_factor=0.,
            progress=False,
        ),
        fit_kwargs=dict(
            gamma_hyper=0.1,
            pi_hyper=1.0,
            rho_hyper=0.5,
            mu_hyper_mean=5,
            mu_hyper_scale=5.,
            m_hyper_r=10.,
            delta_hyper_temp=0.1,
            delta_hyper_p=0.9,
            alpha_hyper_hyper_mean=100.,
            alpha_hyper_hyper_scale=10.,
            alpha_hyper_scale=0.5,
            epsilon_hyper_alpha=1.5,
            epsilon_hyper_beta=1.5 / 0.01,
            device='cpu',
            lag=10,
            lr=1e-0,
            progress=False
        ),
        seed=seed
    )
    print(seed, generr, comperr, scounterr)

## Estimation on Real Data

### Data Loading

In [None]:
# data.py