## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

In [None]:
# Sanity check: the two stick-breaking implementations in sf.model should
# return (numerically) the same probabilities for the same betas.
n = 100
# requires_grad=True so the same tensor can be differentiated in the next cell.
p = torch.tensor(1 - 1e-8, requires_grad=True)

torch.allclose(
    sf.model.stickbreaking_betas_to_probs(torch.ones(n) * p),
    sf.model.stickbreaking_betas_to_probs2(torch.ones(n) * p),
)

In [None]:
# Gradient of the output at index 2 with respect to p, computed once with
# each stick-breaking implementation so the two can be compared side by side.
(
    torch.autograd.grad(sf.model.stickbreaking_betas_to_probs(torch.ones(n) * p)[2], inputs=p)[0],
    torch.autograd.grad(sf.model.stickbreaking_betas_to_probs2(torch.ones(n) * p)[2], inputs=p)[0]
)

## Library

### `__init__`

In [None]:
%%writefile sfacts/__init__.py

from sfacts import (logging_util, pyro_util, pandas_util, model, genotype, plot, estimation, evaluation, workflow, data, app)

### logging_util

In [None]:
%%writefile sfacts/logging_util.py

from datetime import datetime
import sys


def info(*msg, quiet=False):
    """Print a timestamped message to stderr.

    All positional arguments are forwarded to `print` after a `[timestamp]`
    prefix. Nothing is printed when `quiet` is true.
    """
    timestamp = datetime.now()
    if quiet:
        return
    print(f'[{timestamp}]', *msg, file=sys.stderr, flush=True)

### pyro_util

In [None]:
%%writefile sfacts/pyro_util.py

import pyro
import pyro.distributions as dist
import torch
from sfacts.logging_util import info


def as_torch(x, dtype=None, device=None):
    """Build a new torch tensor from `x` with the requested dtype and device."""
    tensor_options = {'dtype': dtype, 'device': device}
    return torch.tensor(x, **tensor_options)

def all_torch(dtype=None, device=None, **kwargs):
    """Convert every keyword argument to a tensor via `as_torch`.

    Returns a dict with the same keys as `kwargs`; `dtype` and `device` are
    applied to every value.
    """
    converted = {}
    for name, value in kwargs.items():
        converted[name] = as_torch(value, dtype=dtype, device=device)
    return converted

def shape_info(model, *args, **kwargs):
    """Trace one execution of `model` and log its formatted shapes.

    `args` and `kwargs` are passed through to the model when it is traced.
    """
    tracer = pyro.poutine.trace(model)
    model_trace = tracer.get_trace(*args, **kwargs)
    # Log-probabilities are computed before formatting the shape table.
    model_trace.compute_log_prob()
    shapes_report = model_trace.format_shapes()
    info(shapes_report)

### pandas_util

In [None]:
%%writefile sfacts/pandas_util.py

import pandas as pd

def idxwhere(condition):
    """Return, as a list, the index labels at which boolean `condition` is True."""
    selected = condition[condition]
    return [label for label in selected.index]

### model

In [None]:
%%writefile sfacts/model.py

from sfacts.pyro_util import as_torch, all_torch
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from torch.nn.functional import pad as torch_pad

def stickbreaking_betas_to_probs(beta):
    """Convert stick-breaking fractions (last axis) into probabilities.

    The output has one more entry along the last axis than `beta`: entry k is
    beta[k] times the product of (1 - beta[j]) for j < k, and the final entry
    is whatever remains of the stick.
    """
    remaining = torch.cumprod(1 - beta, dim=-1)
    # Append a 1 so the last component takes all of the leftover stick.
    break_fraction = torch_pad(beta, (0, 1), value=1)
    # Prepend a 1 because nothing has been removed before the first break.
    stick_left = torch_pad(remaining, (1, 0), value=1)
    return break_fraction * stick_left


def stickbreaking_betas_to_probs2(beta):
    """Log-space variant of `stickbreaking_betas_to_probs`.

    I thought this might be more stable, but it turns out the gradient is
    NOT more stable.
    """
    # log of each break fraction, with a trailing log(1) = 0 for the remainder.
    log_break = torch_pad(torch.log(beta), (0, 1), value=0)
    # Cumulative log of the stick left after each break, with a leading 0.
    log_remaining = torch.cumsum(torch.log(1 - beta), dim=-1)
    log_stick_left = torch_pad(log_remaining, (1, 0), value=0)
    return torch.exp(log_break + log_stick_left)



def NegativeBinomialReparam(mu, r, eps):
    """Build a NegativeBinomial from `mu` and `r`.

    The success probability is 1 / (r / mu + 1), clamped into
    [eps, 1 - eps]; `r` is used as the total count.
    """
    raw_probs = 1. / ((r / mu) + 1.)
    clamped_probs = torch.clamp(raw_probs, eps, 1 - eps)
    return dist.NegativeBinomial(total_count=r, probs=clamped_probs)

def model(
    n,
    g,
    s,
    gamma_hyper=1.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    rho_hyper=1.,
    pi_hyper=1.,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=1.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    mu_hyper_mean=1.,
    mu_hyper_scale=1.,
#     m_hyper_r=1.,
    m_eps=1e-5,
    dtype=torch.float32,
    device='cpu',
):
    """Pyro model with sample, position and strain plates.

    `n`, `g` and `s` are the sizes of the 'sample', 'position' and 'strain'
    plates respectively. Every other numeric argument is a hyperparameter
    that is converted to a tensor of the given `dtype` on `device` before use.

    Sample sites (in order): 'gamma', 'delta', 'rho_betas', 'alpha_hyper_mean',
    'pi', 'epsilon', 'alpha', 'mu', 'm_hyper_r', 'm', 'y'.
    Deterministic sites: 'rho', 'nu', 'p_noerr', 'p'.

    The function has no return value; results are exposed through the Pyro
    sites it registers.
    """
    # Cast all hyperparameters to tensors with a common dtype/device.
    (
        gamma_hyper,
        delta_hyper_temp,
        delta_hyper_p,
        rho_hyper,
        pi_hyper,
        alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale,
        alpha_hyper_scale,
        epsilon_hyper_alpha,
        epsilon_hyper_beta,
        mu_hyper_mean,
        mu_hyper_scale,
#         m_hyper_r,
        m_eps,
    ) = (
        as_torch(x, dtype=dtype, device=device)
        for x in [
            gamma_hyper,
            delta_hyper_temp,
            delta_hyper_p,
            rho_hyper,
            pi_hyper,
            alpha_hyper_hyper_mean,
            alpha_hyper_hyper_scale,
            alpha_hyper_scale,
            epsilon_hyper_alpha,
            epsilon_hyper_beta,
            mu_hyper_mean,
            mu_hyper_scale,
#             m_hyper_r,
            m_eps,
        ]
    )
    
    # Scalar constants in the working dtype/device.
    _zero = as_torch(0, dtype=dtype, device=device).squeeze()
    _one = as_torch(1, dtype=dtype, device=device).squeeze()

    # Genotypes
#     delta_hyper_p = pyro.sample('delta_hyper_p', dist.Beta(1., 1.))
    # 'strain' is plate dim -2 and 'position' is plate dim -1, so gamma and
    # delta are sampled once per (strain, position) pair.
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, logits=_zero)
            )
            # Position presence/absence
            delta = pyro.sample(
                'delta', dist.RelaxedBernoulli(temperature=delta_hyper_temp, probs=delta_hyper_p)
            )
#             delta = pyro.sample(
#                 'delta', dist.Beta(delta_hyper_p * delta_hyper_temp, (1 - delta_hyper_p) * delta_hyper_temp)
#             )
    
    # Meta-community composition
    # s - 1 Beta(1, rho_hyper) stick-breaking fractions are mapped to s
    # probabilities by stickbreaking_betas_to_probs.
    rho_betas = pyro.sample('rho_betas', dist.Beta(1, rho_hyper).expand([s - 1]).to_event())
    rho = pyro.deterministic('rho', stickbreaking_betas_to_probs(rho_betas))
#     rho = pyro.sample('rho', dist.Dirichlet(rho_hyper * torch.ones(s, dtype=dtype, device=device)))

    alpha_hyper_mean = pyro.sample('alpha_hyper_mean', dist.LogNormal(loc=torch.log(alpha_hyper_hyper_mean), scale=alpha_hyper_hyper_scale))
    with pyro.plate('sample', n, dim=-1):
        # Community composition
        pi = pyro.sample('pi', dist.Dirichlet(pi_hyper * rho, validate_args=False))
        # Sequencing error
        epsilon = pyro.sample('epsilon', dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta)).unsqueeze(-1)
        # unsqueeze(-1) adds a trailing axis so the per-sample values can
        # broadcast against per-position quantities below.
        alpha = pyro.sample('alpha', dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale)).unsqueeze(-1)
        # Sample coverage
        mu = pyro.sample('mu', dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale))
        
    # Depth at each position
    nu = pyro.deterministic("nu", pi @ delta)
    m_hyper_r = pyro.sample('m_hyper_r', dist.LogNormal(loc=_one, scale=_one))
    # mu is reshaped to a column so it scales each row of nu.
    m = pyro.sample('m', NegativeBinomialReparam(nu * mu.reshape((-1,1)), m_hyper_r, eps=m_eps).to_event())

    # Expected fractions of each allele at each position
    p_noerr = pyro.deterministic('p_noerr', pi @ (gamma * delta) / nu)
    # Mix p_noerr with its complement, weighted by epsilon / 2.
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )
    
    # Observation
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ).to_event(),
    )
    
def condition_model(model, data=None, device='cpu', dtype=torch.float32, **model_kwargs):
    """Condition `model` on `data` and bind its remaining keyword arguments.

    `data` (a mapping of site name to value, or None for no conditioning) is
    converted to tensors with `all_torch`. The returned callable is the
    conditioned model with `model_kwargs`, `dtype` and `device` pre-bound.
    """
    observed = {} if data is None else data
    observed_tensors = all_torch(**observed, dtype=dtype, device=device)
    conditioned = pyro.condition(model, data=observed_tensors)
    return partial(conditioned, **model_kwargs, dtype=dtype, device=device)
    
def simulate(model):
    """Draw a single sample from `model`.

    Returns a dict mapping each site name to a NumPy array with singleton
    dimensions squeezed out.
    """
    predictive = pyro.infer.Predictive(model, num_samples=1)
    draws = predictive()
    result = {}
    for site in draws.keys():
        result[site] = draws[site].detach().cpu().numpy().squeeze()
    return result

### genotype

In [None]:
%%writefile sfacts/genotype.py

import numpy as np
from scipy.spatial.distance import pdist, squareform
import scipy as sp
from scipy.cluster.hierarchy import linkage
from tqdm import tqdm

def prob_to_sign(gamma):
    """Map `gamma` linearly via 2 * gamma - 1 (0 -> -1, 0.5 -> 0, 1 -> 1)."""
    doubled = gamma * 2
    return doubled - 1

def genotype_distance(x, y):
    """Weighted mean squared half-difference between two genotypes.

    Both inputs are first passed through `prob_to_sign`. Each position is
    weighted by the squared product of the two signed values, so positions
    where either signed value is near 0 contribute little.
    """
    signed_x = prob_to_sign(x)
    signed_y = prob_to_sign(y)
    sq_half_diff = ((signed_x - signed_y) / 2) ** 2
    weight = (signed_x * signed_y) ** 2
    weighted_mean = (weight * sq_half_diff).mean()
    return weighted_mean / weight.mean()

def sign_genotype_distance(x, y):
    """Weighted mean squared half-difference between two signed genotypes.

    Same computation as `genotype_distance`, but `x` and `y` are used as
    given rather than being converted with `prob_to_sign` first.
    """
    sq_half_diff = ((x - y) / 2) ** 2
    weight = (x * y) ** 2
    weighted_mean = (weight * sq_half_diff).mean()
    return weighted_mean / weight.mean()

def genotype_pdist(gamma, progress=False):
    """Pairwise `sign_genotype_distance` between the rows of `gamma`.

    `gamma` is rescaled with 2 * gamma - 1 before comparison. The result is a
    1-D array with one entry per unordered row pair (i < j), ordered by i and
    then j. A tqdm progress bar is shown when `progress` is true.
    """
    metric = sign_genotype_distance
    signed = np.asarray(gamma * 2 - 1)
    n_rows = signed.shape[0]
    n_pairs = (n_rows * (n_rows - 1)) // 2
    dm = np.empty(n_pairs)
    index_pairs = (
        (i, j)
        for i in range(0, n_rows - 1)
        for j in range(i + 1, n_rows)
    )
    with tqdm(total=len(dm), disable=(not progress)) as pbar:
        for k, (i, j) in enumerate(index_pairs):
            dm[k] = metric(signed[i], signed[j])
            pbar.update()
    return dm
    

def counts_to_p_estimate(y, m, pseudo=1):
    """Estimate a proportion from counts `y` out of `m`.

    `pseudo` is added to the numerator and twice to the denominator.
    """
    numerator = y + pseudo
    denominator = m + pseudo * 2
    return numerator / denominator

def genotype_linkage(gamma, progress=False, **kwargs):
    """Hierarchical linkage of the rows of `gamma`.

    Distances come from `genotype_pdist`. `kwargs` are forwarded to `linkage`
    and override the default `method='complete'`. Returns the linkage result
    together with the distance array it was built from.
    """
    dmat = genotype_pdist(gamma, progress=progress)
    linkage_options = {'method': 'complete', **kwargs}
    return linkage(dmat, **linkage_options), dmat

def mask_missing_genotype(gamma, delta):
    """Scale `gamma` in logit space by `delta`.

    Where `delta` is 0 the result is 0.5; where `delta` is 1 `gamma` is
    returned unchanged (up to floating-point error).
    """
    logit_gamma = sp.special.logit(gamma)
    return sp.special.expit(logit_gamma * delta)

### plot

In [None]:
%%writefile sfacts/plot.py

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from sfacts.genotype import genotype_linkage, prob_to_sign
from scipy.spatial.distance import squareform
import pandas as pd
import numpy as np

def calculate_clustermap_dims(nx, ny, scalex=0.15, scaley=0.02, dwidth=0.2, dheight=1.0):
    """Compute figure size and dendrogram ratios for a clustermap.

    The figure width is `nx * scalex + dwidth` and the height is
    `ny * scaley + dheight`. Returns `(fwidth, fheight, dendrogram_ratio)`
    where `dendrogram_ratio` is `(dwidth / fwidth, dheight / fheight)`.
    """
    fwidth = nx * scalex + dwidth
    fheight = ny * scaley + dheight
    return fwidth, fheight, (dwidth / fwidth, dheight / fheight)
    

def plot_genotype(gamma, linkage_kw=None, scalex=0.15, scaley=0.02, dwidth=0.2, dheight=1.0, **kwargs):
    """Draw a clustermap of the transposed, sign-scaled `gamma`.

    Columns are ordered by `genotype_linkage(gamma, **linkage_kw)`. `kwargs`
    override the default clustermap options. The colorbar axis is hidden.
    Returns the seaborn grid.
    """
    linkage_options = {} if linkage_kw is None else linkage_kw
    col_linkage, _ = genotype_linkage(gamma, **linkage_options)

    transposed = gamma.T
    n_rows, n_cols = transposed.shape
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_cols, n_rows, scalex=scalex, scaley=scaley, dwidth=dwidth, dheight=dheight,
    )

    clustermap_options = {
        'vmin': -1,
        'vmax': 1,
        'cmap': 'coolwarm',
        'dendrogram_ratio': dendrogram_ratio,
        'col_linkage': col_linkage,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        'yticklabels': 0,
        **kwargs,
    }
    grid = sns.clustermap(prob_to_sign(transposed), **clustermap_options)
    grid.cax.set_visible(False)
    return grid
    
def plot_missing(delta, scalex=0.15, scaley=0.02, dwidth=0.2, dheight=1.0, **kwargs):
    """Draw a clustermap of the transposed `delta` on a 0-to-1 color scale.

    `kwargs` override the default clustermap options. The colorbar axis is
    hidden. Returns the seaborn grid.
    """
    transposed = delta.T
    n_rows, n_cols = transposed.shape
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_cols, n_rows, scalex=scalex, scaley=scaley, dwidth=dwidth, dheight=dheight,
    )

    clustermap_options = {
        'vmin': 0,
        'vmax': 1,
        'dendrogram_ratio': dendrogram_ratio,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        'yticklabels': 0,
        **kwargs,
    }
    grid = sns.clustermap(transposed, **clustermap_options)
    grid.cax.set_visible(False)
    return grid
    
def plot_community(pi, scalex=0.2, scaley=0.1, dwidth=0.2, dheight=1.0, **kwargs):
    """Draw a clustermap of `pi` using the cosine metric by default.

    `kwargs` override the default clustermap options. The colorbar axis is
    hidden. Returns the seaborn grid.
    """
    n_rows, n_cols = pi.shape
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_cols, n_rows, scalex=scalex, scaley=scaley, dwidth=dwidth, dheight=dheight,
    )

    clustermap_options = {
        'metric': 'cosine',
        'vmin': 0,
        'vmax': 1,
        'dendrogram_ratio': dendrogram_ratio,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        **kwargs,
    }
    grid = sns.clustermap(pi, **clustermap_options)
    grid.cax.set_visible(False)
    return grid
    
def plot_genotype_similarity(gamma, linkage_kw=None, **kwargs):
    """Draw a clustermap of pairwise genotype similarity (1 - distance).

    Rows and columns share the linkage from
    `genotype_linkage(gamma, **linkage_kw)`. `kwargs` override the default
    clustermap options. The colorbar axis is hidden. Returns the seaborn grid.
    """
    linkage_options = {} if linkage_kw is None else linkage_kw
    shared_linkage, condensed = genotype_linkage(gamma, **linkage_options)
    square_dmat = squareform(condensed)

    n_genotypes = gamma.shape[0]
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_genotypes, n_genotypes, scalex=0.15, scaley=0.15, dwidth=0.5, dheight=0.5
    )

    clustermap_options = {
        'vmin': 0,
        'vmax': 1,
        'dendrogram_ratio': dendrogram_ratio,
        'row_linkage': shared_linkage,
        'col_linkage': shared_linkage,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        'yticklabels': 1,
        **kwargs,
    }
    grid = sns.clustermap(1 - square_dmat, **clustermap_options)
    grid.cax.set_visible(False)
    return grid
    
def plot_genotype_comparison(data=None, **kwargs):
    """Stack several genotype matrices row-wise and plot them together.

    `data` maps a label to a 2-D array; row i of the entry for label `k` is
    named `'{k}_{i}'`. The stacked table is passed to `plot_genotype` along
    with `kwargs`.
    """
    frames = []
    for key in data:
        values = data[key]
        row_labels = [f'{key}_{i}' for i in range(values.shape[0])]
        frames.append(pd.DataFrame(values, index=row_labels))
    stacked = pd.concat(frames)
    options = {'xticklabels': 1, **kwargs}
    return plot_genotype(stacked, **options)

def plot_community_comparison(data=None, **kwargs):
    """Stack several community matrices column-wise and plot them together.

    `data` maps a label to a 2-D array; column i of the entry for label `k`
    is named `'{k}_{i}'`. The combined table is passed to `plot_community`
    along with `kwargs`.
    """
    frames = []
    for key in data:
        values = data[key]
        column_labels = [f'{key}_{i}' for i in range(values.shape[1])]
        frames.append(pd.DataFrame(values, columns=column_labels))
    stacked = pd.concat(frames, axis=1)
    options = {'xticklabels': 1, **kwargs}
    return plot_community(stacked, **options)

def plot_loss_history(trace):
    """Plot the loss history, shifted so its minimum is 0, on a log y-axis."""
    losses = np.array(trace)
    shifted = losses - losses.min()
    plt.plot(shifted)
    plt.yscale('log')

### estimation

In [None]:
%%writefile sfacts/estimation.py

from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import non_negative_factorization
from sfacts.genotype import genotype_pdist, mask_missing_genotype
from sfacts.pyro_util import all_torch

import pandas as pd
import numpy as np
import scipy as sp
from scipy.spatial.distance import squareform

import pyro
import pyro.distributions as dist
import torch

from tqdm import tqdm
from sfacts.logging_util import info

from sfacts.model import condition_model

def cluster_genotypes(
    gamma, thresh, progress=False, precomputed_pdist=None
):
    """Cluster genotypes by complete-linkage agglomeration at a distance cutoff.

    Parameters
    ----------
    gamma : 2-D array-like
        One genotype per row; only used when `precomputed_pdist` is None.
    thresh : float
        Distance threshold passed to `AgglomerativeClustering`.
    progress : bool
        Forwarded to `genotype_pdist`.
    precomputed_pdist : 1-D array or None
        Condensed pairwise distances to use instead of recomputing them.

    Returns
    -------
    (clust, compressed_dmat)
        `clust` is a pandas Series of cluster labels (one per genotype) and
        `compressed_dmat` is the condensed distance array that was used.
    """
    if precomputed_pdist is None:
        compressed_dmat = genotype_pdist(gamma, progress=progress)
    else:
        compressed_dmat = precomputed_pdist

    clustering_options = dict(
        n_clusters=None,
        linkage="complete",
        distance_threshold=thresh,
    )
    # scikit-learn renamed `affinity` to `metric` (deprecated in 1.2, removed
    # in 1.4). Prefer the new name and fall back for older releases, where an
    # unknown keyword raises TypeError at construction.
    try:
        clusterer = AgglomerativeClustering(metric="precomputed", **clustering_options)
    except TypeError:
        clusterer = AgglomerativeClustering(affinity="precomputed", **clustering_options)

    labels = clusterer.fit(squareform(compressed_dmat)).labels_
    clust = pd.Series(labels)

    return clust, compressed_dmat

def initialize_parameters_by_clustering_samples(
    y, m, thresh, additional_strains_factor=0.5, progress=False, precomputed_pdist=None,
):
    """Build initial `gamma` and `pi` values by clustering samples.

    Samples (rows of `y` / `m`) are clustered on their pseudocount-smoothed
    allele frequencies. Each cluster becomes one initial genotype, and extra
    all-0.5 genotypes are appended (their number is `additional_strains_factor`
    times the cluster count, truncated to an int). Every sample's initial
    composition puts most of its weight on its own cluster.

    Returns `(gamma_init, pi_init, cdmat)` where `cdmat` is the condensed
    distance array returned by `cluster_genotypes`.
    """
    n, g = y.shape

    sample_genotype = (y + 1) / (m + 2)
    clust, cdmat = cluster_genotypes(
        sample_genotype,
        thresh=thresh,
        progress=progress,
        precomputed_pdist=precomputed_pdist,
    )

    # Pool the counts of all samples in a cluster before estimating frequencies.
    y_total = pd.DataFrame(y).groupby(clust).sum().values
    m_total = pd.DataFrame(m).groupby(clust).sum().values
    clust_genotype = (y_total + 1) / (m_total + 2)

    n_extra = int(additional_strains_factor * clust_genotype.shape[0])
    extra_genotype = np.ones((n_extra, g)) * 0.5
    gamma_init = pd.concat(
        [pd.DataFrame(clust_genotype), pd.DataFrame(extra_genotype)]
    ).values

    s_init = gamma_init.shape[0]
    pi_init = np.ones((n, s_init))
    # Weight s_init - 1 on each sample's own cluster, 1 everywhere else.
    pi_init[np.arange(n), clust.values] = s_init - 1
    pi_init = pi_init / pi_init.sum(1, keepdims=True)

    assert (~np.isnan(gamma_init)).all()

    return gamma_init, pi_init, cdmat

def initialize_parameters_by_nmf(
    y, m, s, progress=False, solver='mu', alpha=100., l1_ratio=1.0, tol=1e-2, **kwargs
):
    """Build initial `gamma` and `pi` values by non-negative factorization.

    The counts `y` and `m - y` are concatenated column-wise and factorized
    into `s` components with `non_negative_factorization`; `kwargs` are
    forwarded to it. Returns `(gamma_init, pi_init, None)`.
    """
    n, g = y.shape

    # Factorize both allele counts at once by placing them side by side.
    stacked_counts = np.concatenate([y, m - y], axis=1)
    nmf_options = dict(
        n_components=s,
        solver=solver,
        verbose=int(progress),
        alpha=alpha,
        l1_ratio=l1_ratio,
        tol=tol,
    )
    pi_unnorm, gamma_unnorm, _ = non_negative_factorization(
        stacked_counts, **nmf_options, **kwargs
    )

    # TODO: Find a more principled way to convert pi_unnorm into pi_init.
    pi_smoothed = pi_unnorm + 1
    pi_init = pi_smoothed / pi_smoothed.sum(1, keepdims=True)

    # The first g columns correspond to `y`, the last g to `m - y`.
    first_half = gamma_unnorm[:, :g]
    second_half = gamma_unnorm[:, -g:]
    gamma_init = (first_half + 1) / (first_half + second_half + 2)

    return gamma_init, pi_init, None


def estimate_parameters(
    model,
    data,
    dtype=torch.float32,
    device='cpu',
    initialize_params=None,
    maxiter=10000,
    lagA=20,
    lagB=100,
    opt=None,
    progress=True,
    **model_kwargs,
):
    """Fit `model` to `data` with SVI and an AutoLaplaceApproximation guide.

    Parameters
    ----------
    model : callable
        Pyro model; conditioned on `data` via `condition_model`.
    data : dict
        Site name -> observed value.
    dtype, device
        Passed to `condition_model` and used to convert `initialize_params`.
    initialize_params : dict or None
        Site name -> initial value for the guide (`init_to_value`).
    maxiter : int
        Maximum number of SVI steps.
    lagA, lagB : int
        Window lengths (in steps) of the two convergence checks.
    opt : Pyro optimizer or None
        Optimizer to use. When None, a fresh
        `pyro.optim.Adamax({"lr": 1e-2}, {"clip_norm": 100})` is created for
        this call, so no optimizer state is shared between calls.
    progress : bool
        Show a tqdm progress bar.
    **model_kwargs
        Bound to the model by `condition_model`.

    Returns
    -------
    (est, history)
        `est` maps site names to NumPy arrays from a single predictive draw;
        `history` is the list of per-step loss values.

    Raises
    ------
    RuntimeError
        If a step returns NaN.
    """
    if opt is None:
        # Built per call: a default created in the signature would be a single
        # object constructed at import time and reused by every call.
        opt = pyro.optim.Adamax({"lr": 1e-2}, {"clip_norm": 100})
    if initialize_params is None:
        initialize_params = {}

    conditioned_model = condition_model(
        model,
        data=data,
        dtype=dtype,
        device=device,
        **model_kwargs,
    )

    _guide = pyro.infer.autoguide.AutoLaplaceApproximation(
        conditioned_model,
        init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
            values=all_torch(**initialize_params, dtype=dtype, device=device)
        ),
    )
    svi = pyro.infer.SVI(
        conditioned_model,
        _guide,
        opt,
        loss=pyro.infer.JitTrace_ELBO()
    )
    pyro.clear_param_store()

    history = []
    pbar = tqdm(range(maxiter), disable=(not progress))
    try:
        for i in pbar:
            elbo = svi.step()

            if np.isnan(elbo):
                pbar.close()
                raise RuntimeError("ELBO NaN?")

            # Fit tracking
            history.append(elbo)

            # Reporting/Breaking: checked every 10th step once more than
            # lagB steps have been recorded.
            if (i % 10 == 0) and (i > lagB):
                delta = history[-2] - history[-1]
                # Average per-step decrease over each lag window.
                delta_lagA = (history[-lagA] - history[-1]) / lagA
                delta_lagB = (history[-lagB] - history[-1]) / lagB
                # Stop once the loss has not decreased over either window.
                if (delta_lagA <= 0) and (delta_lagB <= 0):
                    if progress:
                        pbar.close()
                        info("Converged")
                    break
                pbar.set_postfix({
                    'ELBO': history[-1],
                    'delta': delta,
                    f'lag{lagA}': delta_lagA,
                    f'lag{lagB}': delta_lagB,
                })
    except KeyboardInterrupt:
        # An interrupt ends optimization early; the current fit is still returned.
        pbar.close()
        info("Interrupted")

    est = pyro.infer.Predictive(conditioned_model, guide=_guide, num_samples=1)()
    est = {
        k: est[k].detach().cpu().numpy().mean(0).squeeze()
        for k in est.keys()
    }

    # str() so that both 'cuda:0' strings and torch.device objects work.
    if str(device).startswith("cuda"):
        torch.cuda.empty_cache()

    return est, history


def merge_similar_genotypes(
    gamma, pi, thresh, delta=None, progress=False,
):
    """Merge genotypes that cluster together at distance `thresh`.

    `gamma` is first adjusted with `mask_missing_genotype(gamma, delta)`
    (`delta` defaults to all ones) and clustered with `cluster_genotypes`.
    Within each cluster, genotypes and `delta` rows are averaged and the
    corresponding columns of `pi` are summed.

    Returns `(gamma_mean, pi_sum, delta_mean)`.
    """
    if delta is None:
        delta = np.ones_like(gamma)

    gamma_adjust = mask_missing_genotype(gamma, delta)
    clust, _ = cluster_genotypes(gamma_adjust, thresh=thresh, progress=progress)

    def _cluster_mean_genotype(block):
        # NOTE(review): expit(logit(x)) is applied before the mean, so this is
        # effectively a plain mean of x; confirm whether a logit-space mean
        # was intended.
        return sp.special.expit(sp.special.logit(block)).mean(0)

    gamma_frame = pd.DataFrame(gamma_adjust)
    gamma_mean = gamma_frame.groupby(clust).apply(_cluster_mean_genotype).values

    delta_frame = pd.DataFrame(delta)
    delta_mean = delta_frame.groupby(clust).mean().values

    pi_frame = pd.DataFrame(pi)
    pi_sum = pi_frame.groupby(clust, axis='columns').sum().values

    return gamma_mean, pi_sum, delta_mean

### evaluation

In [None]:
%%writefile sfacts/evaluation.py
from scipy.spatial.distance import cdist, pdist
import pandas as pd
import numpy as np


def binary_entropy(p):
    """Elementwise binary entropy of `p`, in bits.

    Computes -(p * log2(p) + (1 - p) * log2(1 - p)), with the convention
    0 * log(0) = 0 so that p == 0 and p == 1 give an entropy of 0 rather
    than NaN.
    """
    # xlogy(x, y) evaluates x * log(y) and returns 0 wherever x == 0.
    from scipy.special import xlogy

    q = 1 - p
    # Natural-log result divided by log(2) to express the entropy in bits.
    ent = -(xlogy(p, p) + xlogy(q, q)) / np.log(2)
    return ent

def sum_binary_entropy(p, normalize=False, axis=None):
    """Sum of elementwise binary entropy (in bits) of `p` over `axis`.

    Parameters
    ----------
    p : array-like of probabilities
    normalize : bool
        If true, divide the sum by the number of elements that were summed,
        giving a mean entropy.
    axis : None, int or tuple of ints
        Axis or axes to sum over; None sums over all elements.
    """
    q = 1 - p
    ent = np.sum(-(p * np.log2(p) + q * np.log2(q)), axis=axis)
    if normalize:
        if axis is None:
            # All elements were summed, so normalize by the total count
            # (indexing the shape with None would raise TypeError).
            count = np.size(p)
        else:
            shape = np.shape(p)
            count = int(np.prod([shape[a] for a in np.atleast_1d(axis)]))
        ent = ent / count
    return ent

def mean_masked_genotype_entropy(gamma, delta):
    """Mean over axis 1 of `binary_entropy(gamma)` weighted elementwise by `delta`."""
    per_position = binary_entropy(gamma)
    masked = per_position * delta
    return masked.mean(1)

def sample_mean_masked_genotype_entropy(pi, gamma, delta):
    """Weight the per-genotype masked entropy by `pi` via a matrix product.

    The result of `mean_masked_genotype_entropy(gamma, delta)` is reshaped to
    a column, left-multiplied by `pi`, and squeezed.
    """
    strain_entropy = mean_masked_genotype_entropy(gamma, delta)
    as_column = strain_entropy.reshape((-1, 1))
    return (pi @ as_column).squeeze()

def match_genotypes(gammaA, gammaB):
    """Find, for every row of `gammaA`, the closest row of `gammaB`.

    Closeness is the cityblock distance. Returns a pair of pandas Series
    indexed by row of `gammaA`: the position of the best match in `gammaB`,
    and its distance divided by the number of columns of `gammaA`.
    """
    n_positions = gammaA.shape[1]
    pairwise = cdist(gammaA, gammaB, metric='cityblock')
    table = pd.DataFrame(pairwise)
    best_match = table.idxmin(axis=1)
    best_distance = table.min(axis=1) / n_positions
    return best_match, best_distance

def _rmse(x, y):
    """Root of the mean squared difference between `x` and `y`."""
    residual = x - y
    return np.sqrt(np.mean(np.square(residual)))

def _rss(x, y):
    """Square root of the summed squared difference between `x` and `y`."""
    residual = x - y
    return np.sqrt(np.sum(np.square(residual)))

def community_accuracy_test(pi_sim, pi_fit, reps=99):
    """Compare pairwise Bray-Curtis similarities of two community matrices.

    The error is the RMSE between the pairwise (1 - Bray-Curtis) values of
    `pi_sim` and of `pi_fit`. A null distribution of `reps` errors is built by
    comparing the `pi_sim` values against random permutations of themselves.

    Returns `(err, null, err / mean(null), fraction of null values below err)`.
    """
    bc_sim = 1 - pdist(pi_sim, metric='braycurtis')
    bc_fit = 1 - pdist(pi_fit, metric='braycurtis')
    err = _rmse(bc_sim, bc_fit)

    null = np.array([
        _rmse(bc_sim, np.random.permutation(bc_sim))
        for _ in range(reps)
    ])

    error_ratio = err / np.mean(null)
    frac_null_below = (np.sort(null) < err).mean()
    return err, null, error_ratio, frac_null_below

def metacommunity_composition_rss(pi_sim, pi_fit):
    """Compare the sorted mean compositions of two community matrices.

    Column means of each matrix are zero-padded to a common length, sorted,
    and compared with `_rss`.
    """
    mean_sim = pi_sim.mean(0)
    mean_fit = pi_fit.mean(0)
    s = max(mean_sim.shape[0], mean_fit.shape[0])

    def _pad_and_sort(abundance):
        # Zero-pad on the right up to length s, then sort ascending.
        return np.sort(np.pad(abundance, pad_width=(0, s - abundance.shape[0])))

    return _rss(_pad_and_sort(mean_sim), _pad_and_sort(mean_fit))

### data

In [None]:
%%writefile sfacts/data.py

from sfacts.logging_util import info
from sfacts.pandas_util import idxwhere
import xarray as xr


class Metagenotype():
    """Container holding an `xarray.Dataset` built from the constructor arguments."""

    def __init__(self, *args, **kwargs):
        # Store the dataset on the instance; binding it to a local name would
        # discard it as soon as __init__ returns.
        self.data = xr.Dataset(*args, **kwargs)

def load_input_data(allpaths):
    """Open each path as a data array and concatenate along 'library_id'.

    Each array is squeezed after opening. Entries missing after concatenation
    are filled with 0. Progress is logged with `info`.
    """
    loaded = []
    for path in allpaths:
        info(path)
        array = xr.open_dataarray(path).squeeze()
        info(f"Shape: {array.sizes}.")
        loaded.append(array)
    info("Concatenating data from {} files.".format(len(loaded)))
    combined = xr.concat(loaded, "library_id", fill_value=0)
    info(f"Finished concatenating data: {combined.sizes}")
    return combined

def select_informative_positions(data, incid_thresh):
    """Return the positions whose minor-allele incidence exceeds `incid_thresh`.

    Incidence is the fraction of libraries with a count > 0 for an allele;
    the minimum over alleles is compared against the threshold.
    """
    allele_incidence = (data > 0).mean("library_id")
    minor_allele_incid = allele_incidence.min("allele")
    passes_threshold = minor_allele_incid.to_series() > incid_thresh
    return idxwhere(passes_threshold)

### workflow

In [None]:
%%writefile sfacts/workflow.py

import pyro
from sfacts.pandas_util import idxwhere
from sfacts.model import model, simulate, condition_model
from sfacts.estimation import (
    initialize_parameters_by_clustering_samples,
    initialize_parameters_by_nmf,
    estimate_parameters,
    merge_similar_genotypes
)
from sfacts.genotype import mask_missing_genotype
from sfacts.evaluation import (
    match_genotypes,
    sample_mean_masked_genotype_entropy,
    community_accuracy_test,
    metacommunity_composition_rss,
)
from sfacts.data import load_input_data, select_informative_positions
import time
import numpy as np
from sfacts.logging_util import info

def fit_to_data(
    y,
    m,
    fit_kwargs,
    initialize='nmf',
    initialize_kwargs=None,
    postclust=True,
    postclust_kwargs=None,
    seed=1,
    quiet=False,
    additional_conditioning_data=None
):
    """Initialize, fit and optionally dereplicate strains for counts `y` / `m`.

    Parameters
    ----------
    y, m : 2-D arrays (samples x positions)
        Observed counts; passed to the model as the 'y' and 'm' sites.
    fit_kwargs : dict
        Extra keyword arguments for `estimate_parameters`. When `initialize`
        is falsy, the key 's' is required and is popped from this dict
        in place.
    initialize : 'nmf', 'clust', or a falsy value
        Initialization strategy; a falsy value skips initialization.
    initialize_kwargs : dict or None
        Keyword arguments for the chosen initializer (required for 'nmf'
        and 'clust').
    postclust : bool
        Whether to merge similar genotypes after fitting.
    postclust_kwargs : dict or None
        Keyword arguments for `merge_similar_genotypes`; None means no extras.
    seed : int
        RNG seed, set both before initialization and before fitting.
    quiet : bool
        Suppress log messages.
    additional_conditioning_data : dict or None
        Extra sites to condition the model on.

    Returns
    -------
    (mrg, fit, history)
        `fit` is the raw estimate, `mrg` the estimate after dereplication
        (the same object as `fit` when `postclust` is false), and `history`
        the loss trace.

    Raises
    ------
    NotImplementedError
        If `initialize` is an unknown, truthy strategy name.
    """
    if additional_conditioning_data is None:
        additional_conditioning_data = {}
    if postclust_kwargs is None:
        # Unpacking None with ** would raise; treat it as "no extra options".
        postclust_kwargs = {}

    n, g = y.shape
    info(f"Setting RNG seed to {seed}.", quiet=quiet)
    pyro.util.set_rng_seed(seed)
    if initialize == 'nmf':
        info(f"Initializing {n} samples and {g} positions using NMF.", quiet=quiet)
        assert initialize_kwargs is not None
        gamma_init, pi_init, _ = initialize_parameters_by_nmf(
            y,
            m,
            random_state=seed,
            **initialize_kwargs
        )
        initialize_params = dict(gamma=gamma_init, pi=pi_init)
        s_fit = gamma_init.shape[0]
        info(f"Initialized {s_fit} strains in {n} samples.", quiet=quiet)
    elif initialize == 'clust':
        info(f"Initializing {n} samples and {g} positions using hierarchical clustering.", quiet=quiet)
        assert initialize_kwargs is not None
        gamma_init, pi_init, _ = initialize_parameters_by_clustering_samples(
            y,
            m,
            **initialize_kwargs
        )
        initialize_params = dict(gamma=gamma_init, pi=pi_init)
        s_fit = gamma_init.shape[0]
        info(f"Initialized {s_fit} strains in {n} samples.", quiet=quiet)
    elif not initialize:
        initialize_params = None
        # NOTE: removes 's' from the caller's dict so it is not passed twice
        # to estimate_parameters below.
        s_fit = fit_kwargs.pop('s')
    else:
        raise NotImplementedError(f"Initializing strategy: '{initialize}' not known.")

    info("Optimizing model parameters.", quiet=quiet)
    # Re-seed so the fit does not depend on how much randomness the
    # initialization consumed.
    info(f"Setting RNG seed to {seed}.", quiet=quiet)
    pyro.util.set_rng_seed(seed)
    fit, history = estimate_parameters(
        model,
        data=dict(y=y, m=m, **additional_conditioning_data),
        n=n,
        g=g,
        s=s_fit,
        initialize_params=initialize_params,
        **fit_kwargs,
    )

    if postclust:
        info("Dereplicating highly similar strains.", quiet=quiet)
        merge_gamma, merge_pi, merge_delta = merge_similar_genotypes(
            fit['gamma'],
            fit['pi'],
            delta=fit['delta'],
            **postclust_kwargs,
        )
        mrg = fit.copy()
        mrg['gamma'] = merge_gamma
        mrg['pi'] = merge_pi
        mrg['delta'] = merge_delta
        s_mrg = mrg['gamma'].shape[0]
        info(f"Original {s_fit} strains down to {s_mrg} after dereplication.", quiet=quiet)
    else:
        mrg = fit
    info("Finished fitting to data.", quiet=quiet)

    return mrg, fit, history

def simulate_fit_and_evaluate(
    s_sim,
    n_sim,
    g_sim,
    n_fit,
    g_fit,
    sim_kwargs,
    fit_kwargs,
    seed_sim=1,
    seed_fit=1,
    preclust=True,
    preclust_kwargs=None,
    postclust=True,
    postclust_kwargs=None,
    quiet=False,
):
    """Simulate data from the model, fit a subset of it, and score the fit.

    Parameters
    ----------
    s_sim, n_sim, g_sim : int
        Number of strains, samples and positions to simulate.
    n_fit, g_fit : int
        Number of leading samples / positions of the simulation to fit.
    sim_kwargs : dict
        Extra model keyword arguments for the simulation.
    fit_kwargs : dict
        Passed to `fit_to_data`; a shallow copy is used so the caller's dict
        is not modified. Must contain 's' when `preclust` is false.
    seed_sim, seed_fit : int
        RNG seeds for simulation and fitting.
    preclust : bool
        If true, initialize the fit by clustering samples
        (`initialize='clust'`); otherwise no initialization is used.
    preclust_kwargs : dict or None
        Keyword arguments for the clustering initializer.
    postclust, postclust_kwargs
        Forwarded to `fit_to_data`.
    quiet : bool
        Suppress log messages.

    Returns
    -------
    tuple
        (weighted_mean_genotype_error, beta_diversity_error_ratio,
        metacommunity_composition_error,
        mean_sample_weighted_genotype_entropy, runtime, sim, mrg)
    """
    info(f"Setting RNG seed to {seed_sim}.", quiet=quiet)
    pyro.util.set_rng_seed(seed_sim)
    info("Simulating data from model.", quiet=quiet)
    sim = simulate(
        condition_model(
            model,
            n=n_sim,
            g=g_sim,
            s=s_sim,
            **sim_kwargs,
        )
    )

    info("Starting runtime clock.", quiet=quiet)
    start_time = time.time()
    # fit_to_data has no `preclust` / `preclust_kwargs` parameters; the
    # pre-clustering option is expressed through its `initialize` arguments.
    mrg, fit, history = fit_to_data(
        sim['y'][:n_fit, :g_fit],
        sim['m'][:n_fit, :g_fit],
        fit_kwargs=dict(fit_kwargs),
        initialize=('clust' if preclust else None),
        initialize_kwargs=preclust_kwargs,
        postclust=postclust,
        postclust_kwargs=postclust_kwargs,
        seed=seed_fit,
        quiet=quiet,
    )
    end_time = time.time()
    info("Stopping runtime clock.", quiet=quiet)

    info("Calculating statistics.", quiet=quiet)

    sim_gamma_adj = mask_missing_genotype(sim['gamma'][:, :g_fit], sim['delta'][:, :g_fit])
    mrg_gamma_adj = mask_missing_genotype(mrg['gamma'], mrg['delta'])
    best_hit, best_dist = match_genotypes(sim_gamma_adj, mrg_gamma_adj)
    # Weight each simulated strain's best-match error by its mean abundance
    # in the fitted samples.
    weighted_mean_genotype_error = (best_dist * sim['pi'][:n_fit].mean(0)).sum()
    runtime = end_time - start_time

    _, _, beta_diversity_error_ratio, _ = (
        community_accuracy_test(sim['pi'][:n_fit], mrg['pi'])
    )

    # NOTE(review): this uses all simulated samples, whereas the statistics
    # above use only the first n_fit; confirm that is intended.
    metacommunity_composition_error = metacommunity_composition_rss(sim['pi'], mrg['pi'])

    mean_sample_weighted_genotype_entropy = (
        sample_mean_masked_genotype_entropy(mrg['pi'], mrg['gamma'], mrg['delta']).mean()
    )
    info("Finished calculating statistics.", quiet=quiet)

    return (
        weighted_mean_genotype_error,
        beta_diversity_error_ratio,
        metacommunity_composition_error,
        mean_sample_weighted_genotype_entropy,
        runtime,
        sim,
        mrg
    )


def filter_data(
    data, 
    incid_thresh=0.1,
    cvrg_thresh=0.15,
):
    """Select informative positions and sufficiently covered libraries.

    Positions are chosen with `select_informative_positions`. A library is
    kept when the fraction of informative positions at which it has any
    counts (summed over alleles) exceeds `cvrg_thresh`.

    Returns `(informative_positions, suff_cvrg_samples)`.
    """
    info("Filtering positions.")
    informative_positions = select_informative_positions(data, incid_thresh)
    npos_available = len(informative_positions)
    info(
        f"Found {npos_available} informative positions with minor "
        f"allele incidence of >{incid_thresh}"
    )

    info("Filtering libraries.")
    informative_depth = data.sel(position=informative_positions).sum(["allele"])
    frac_positions_covered = (informative_depth > 0).mean("position")
    has_sufficient_coverage = (frac_positions_covered > cvrg_thresh).to_series()
    suff_cvrg_samples = idxwhere(has_sufficient_coverage)
    nlibs = len(suff_cvrg_samples)
    info(
        f"Found {nlibs} libraries with >{cvrg_thresh:0.1%} "
        f"of informative positions covered."
    )
    return informative_positions, suff_cvrg_samples

def sample_positions(
    informative_positions,
    npos=1000,
    seed=None,
):
    """Randomly choose up to `npos` positions without replacement.

    Parameters
    ----------
    informative_positions : sequence
        Candidate positions.
    npos : int
        Requested number of positions; capped at the number available.
    seed : int or None
        If given, seeds NumPy's global RNG before sampling.

    Returns
    -------
    numpy.ndarray
        The sampled positions.
    """
    if seed is not None:
        info(f"Setting RNG seed to {seed}.")
        np.random.seed(seed)
    npos_available = len(informative_positions)
    _npos = min(npos, npos_available)
    # Report the number actually drawn, which may be fewer than requested.
    info(f"Randomly sampling {_npos} positions.")
    position_ss = np.random.choice(
        informative_positions,
        size=_npos,
        replace=False,
    )
    info("Finished sampling.")
    return position_ss

def filter_subsample_and_fit(
    data,
    incid_thresh=0.1,
    cvrg_thresh=0.15,
    npos=1000,
    seed=1,
    **fit_to_data_kwargs
):
    """Filter `data`, subsample positions, and fit the model to the subset.

    Filtering uses `filter_data`, subsampling uses `sample_positions`, and
    the remaining keyword arguments are forwarded to `fit_to_data`.

    Returns `(mrg_ss, data_fit, history)`: the dereplicated fit, the data
    subset that was fitted, and the loss trace.
    """
    info(f"Full data shape: {data.sizes}.")
    informative_positions, suff_cvrg_samples = filter_data(
        data, incid_thresh=incid_thresh, cvrg_thresh=cvrg_thresh
    )
    position_ss = sample_positions(informative_positions, npos, seed=seed)

    info("Constructing input data.")
    data_fit = data.sel(library_id=suff_cvrg_samples, position=position_ss)
    # Total depth is the sum over alleles.
    m_ss = data_fit.sum("allele")
    n, g_ss = m_ss.shape
    y_obs_ss = data_fit.sel(allele="alt")

    fit_result = fit_to_data(
        y_obs_ss.values,
        m_ss.values,
        seed=seed,
        **fit_to_data_kwargs,
    )
    mrg_ss, fit_ss, history = fit_result
    return mrg_ss, data_fit, history
    

def filter_subsample_fit_and_refit_genotypes(
    data,
    fit_kwargs,
    incid_thresh=0.1,
    cvrg_thresh=0.15,
    npos=1000,
    seed=1,
    **kwargs
):
    """Fit on a position subsample, then refit genotypes window by window.

    Steps:
      1. Filter positions and libraries (`filter_data`) and subsample up to
         `npos` informative positions (`sample_positions`).
      2. Fit the model to the subsample with `fit_to_data`.
      3. Holding every fitted site fixed except the position-dependent ones,
         refit windows of positions with `estimate_parameters` and
         concatenate the per-window results along axis 1.

    Parameters
    ----------
    data : xarray data with 'library_id', 'position' and 'allele' dimensions
    fit_kwargs : dict
        Keyword arguments for `estimate_parameters`, used for both the
        initial fit and the refits (with any 's' entry dropped for refits).
    incid_thresh, cvrg_thresh : float
        Filtering thresholds, see `filter_data`.
    npos : int
        Number of positions to subsample and the window step.
    seed : int
        RNG seed for subsampling and the initial fit.
    **kwargs
        Forwarded to `fit_to_data`.

    Returns
    -------
    (out, data_filt, informative_positions, position_ss)
    """
    info(f"Full data shape: {data.sizes}.")
    informative_positions, suff_cvrg_samples = filter_data(
        data, incid_thresh=incid_thresh, cvrg_thresh=cvrg_thresh
    )
    position_ss = sample_positions(informative_positions, npos, seed=seed)
    info("Constructing input data.")
    data_filt = data.sel(library_id=suff_cvrg_samples)
    data_ss = data_filt.sel(position=position_ss)
    m_ss = data_ss.sum("allele")
    n, g_ss = m_ss.shape
    y_obs_ss = data_ss.sel(allele="alt")
    mrg_ss, fit_ss, history = fit_to_data(
        y_obs_ss.values,
        m_ss.values,
        seed=seed,
        fit_kwargs=fit_kwargs,
        **kwargs,
    )

    info("Refitting genotypes at all positions")
    s = mrg_ss['gamma'].shape[0]
    refit_kwargs = fit_kwargs.copy()
    # The strain count is passed explicitly below; drop any 's' entry (the
    # dict key, not the integer count) so it is not supplied twice.
    refit_kwargs.pop('s', None)
    n = len(suff_cvrg_samples)
    g_total = len(informative_positions)

    # Sites whose shape depends on the set of positions are re-estimated per
    # window; 'rho' is also dropped. Everything else is held fixed.
    windowed_sites = ['gamma', 'delta', 'nu', 'm', 'p_noerr', 'p', 'y']
    fixed = mrg_ss.copy()
    for k in windowed_sites + ['rho']:
        del fixed[k]

    # NOTE(review): y and m span every position in data_filt, while g_total
    # counts only the informative positions, so the windows below cover the
    # first g_total columns of the unfiltered position axis. Confirm whether
    # data_filt should be restricted to informative_positions here.
    y = data_filt.sel(allele="alt").values
    m = data_filt.sum("allele").values
    out = fixed.copy()
    chunks = {k: [] for k in windowed_sites}

    # Ceiling division: the number of iterations of the range() below.
    nwindows = -(-g_total // npos)
    for window_i, j_start in enumerate(range(0, g_total, npos)):
        window_ip1 = window_i + 1
        info(f"Fitting genotype window {window_ip1} of {nwindows}.")
        j_stop = min(j_start + g_ss, g_total)
        refit, history = estimate_parameters(
            model,
            data=dict(y=y[:, j_start:j_stop], m=m[:, j_start:j_stop], **fixed),
            n=n,
            g=j_stop - j_start,
            s=s,
            **refit_kwargs,
        )
        for k in windowed_sites:
            chunks[k].append(refit[k])
        info(f"Finished fitting genotype window {window_ip1} of {nwindows}.")
    info("Finished all windows.")

    for k in windowed_sites:
        out[k] = np.concatenate(chunks[k], axis=1)
    info("Finished constructing arrays.")

    return out, data_filt, informative_positions, position_ss

### app

In [None]:
%%writefile sfacts/app.py
#!/usr/bin/env python3

import sys
import argparse
import warnings
import xarray as xr
import pandas as pd
import sfacts


def parse_args(argv):
    """Parse the command-line arguments of the app.

    The options are declared as a table and registered in the order they
    are listed, so the generated ``--help`` text keeps its grouping.

    Args:
        argv: Argument strings without the program name (the script entry
            point passes ``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` returned by
        ``ArgumentParser.parse_args``.
    """
    # Reproduced verbatim, including its leading newline and trailing blanks.
    pileup_help = """
Single, fully processed, pileup table in NetCDF format
with the following dimensions:
    * library_id
    * position
    * allele
                        """

    def integer(default, **extra):
        # Keyword arguments shared by the integer-valued model options.
        return dict(metavar="INT", type=int, default=default, **extra)

    def real(default, **extra):
        # Keyword arguments shared by the float-valued options.
        return dict(metavar="FLOAT", type=float, default=default, **extra)

    def fitting(kind, default):
        # Fitting options have no metavar and a placeholder help string.
        return dict(type=kind, default=default, help="FIXME")

    arguments = [
        # Input
        ("pileup", dict(nargs="+", help=pileup_help)),
        # Shape of the model
        ("--nstrains", integer(1000)),
        (
            "--npos",
            integer(
                2000,
                help="Number of positions to sample for model fitting.",
            ),
        ),
        # Data filtering
        (
            "--incid-thresh",
            real(
                0.02,
                help=(
                    "Minimum fraction of samples that must have the minor "
                    "allele for the position to be considered 'informative'."
                ),
            ),
        ),
        (
            "--cvrg-thresh",
            real(
                0.5,
                help=(
                    "Minimum fraction of 'informative' positions with counts "
                    "necessary for sample to be included."
                ),
            ),
        ),
        # Regularization
        (
            "--gamma-hyper",
            real(1e-2, help="Ambiguity regularization parameter."),
        ),
        (
            "--pi-hyper",
            real(
                1e-1,
                help=(
                    "Heterogeneity regularization parameter "
                    "(will be scaled by 1 / s)."
                ),
            ),
        ),
        (
            "--rho-hyper",
            real(1e0, help="Diversity regularization parameter."),
        ),
        ("--epsilon-hyper", real(0.01)),
        ("--epsilon", real(None, help="Fixed error rate for all samples.")),
        ("--alpha-hyper", real(100.0)),
        (
            "--alpha",
            real(None, help="Fixed concentration for all samples."),
        ),
        (
            "--collapse",
            real(
                0.0,
                help=(
                    "Merge strains with a cosine distance of less than "
                    "this value."
                ),
            ),
        ),
        # Fitting
        ("--random-seed", fitting(int, 0)),
        ("--max-iter", fitting(int, 10000)),
        ("--lag", fitting(int, 50)),
        ("--stop-at", fitting(float, 5.0)),
        ("--learning-rate", fitting(float, 1e-0)),
        ("--clip-norm", fitting(float, 100.0)),
        # Hardware
        ("--device", dict(default="cpu", help="PyTorch device name.")),
        # Output
        (
            "--outpath",
            dict(metavar="PATH", help="Path for genotype fraction output."),
        ),
    ]

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, options in arguments:
        parser.add_argument(flag, **options)

    return parser.parse_args(argv)


# Script entry point: configure warning filters, parse arguments, filter the
# pileup, fit the model, and optionally write the result to NetCDF.
#
# NOTE(review): several names used below are not bound by the imports at the
# top of this file (sys, argparse, warnings, xarray, pandas, sfacts): torch,
# np, info, _load_input_data, select_informative_positions, idxwhere,
# find_map, model_fit, as_torch_all, init_genotype, init_frac, mapest3 and
# s_collapse. As written this block looks like it would raise NameError;
# confirm where each should come from. Presumably `info` is
# sfacts.logging_util.info, `idxwhere` is sfacts.pandas_util.idxwhere and
# `as_torch_all` is sfacts.pyro_util.all_torch (TODO confirm).
if __name__ == "__main__":
    # Turn "Encountered NaN: loss" UserWarnings into exceptions.
    # NOTE(review): `lineno` ties the filter to one specific source line of
    # the emitting module, so it may stop matching after a library upgrade.
    warnings.filterwarnings(
        "error",
        message="Encountered NaN: loss",
        category=UserWarning,
        # module="trace_elbo",  # FIXME: What is the correct regex for module?
        lineno=217,
    )
    # Silence the warning emitted when no NVIDIA driver is found.
    warnings.filterwarnings(
        "ignore",
        message="CUDA initialization: Found no NVIDIA",
        category=UserWarning,
        lineno=130,
    )
    # Silence the tracer warning about torch.tensor results being constants.
    warnings.filterwarnings(
        "ignore",
        message="torch.tensor results are registered as constants",
        category=torch.jit.TracerWarning,
        # module="trace_elbo",  # FIXME: What is the correct regex for module?
        lineno=95,
    )

    args = parse_args(sys.argv[1:])
    info(args)

    # NOTE(review): only NumPy's global RNG is seeded in this block; no
    # torch/pyro seeding call appears here.
    info(f"Setting random seed: {args.random_seed}")
    np.random.seed(args.random_seed)

    info("Loading input data.")
    data = _load_input_data(args.pileup)
    info(f"Full data shape: {data.sizes}.")

    # Pick the positions that pass the incidence threshold, then draw a
    # random subset (without replacement) of at most --npos of them.
    info("Filtering positions.")
    informative_positions = select_informative_positions(
        data, args.incid_thresh
    )
    npos_available = len(informative_positions)
    info(
        f"Found {npos_available} informative positions with minor "
        f"allele incidence of >{args.incid_thresh}"
    )
    npos = min(args.npos, npos_available)
    info(f"Randomly sampling {npos} positions.")
    position_ss = np.random.choice(
        informative_positions,
        size=npos,
        replace=False,
    )

    # Keep libraries whose fraction of informative positions with a
    # non-zero count (summed over alleles) exceeds --cvrg-thresh.
    info("Filtering libraries.")
    suff_cvrg_samples = idxwhere(
        (
            (
                data.sel(position=informative_positions).sum(["allele"]) > 0
            ).mean("position")
            > args.cvrg_thresh
        ).to_series()
    )
    nlibs = len(suff_cvrg_samples)
    info(
        f"Found {nlibs} libraries with >{args.cvrg_thresh:0.1%} "
        f"of informative positions covered."
    )

    # m_ss: counts summed over alleles; y_obs_ss: counts of the "alt" allele.
    # NOTE(review): `n` and `g_ss` are not read again in this block.
    info("Constructing input data.")
    data_fit = data.sel(library_id=suff_cvrg_samples, position=position_ss)
    m_ss = data_fit.sum("allele")
    n, g_ss = m_ss.shape
    y_obs_ss = data_fit.sel(allele="alt")

    # NOTE(review): the model-shape and regularization options parsed above
    # (nstrains, gamma_hyper, pi_hyper, rho_hyper, epsilon_hyper, epsilon,
    # alpha_hyper, alpha, collapse) are never read in this block, and
    # `mapest1` is assigned here but never read afterwards.
    info("Optimizing model parameters.")
    mapest1, history1 = find_map(
        model_fit,
        init=as_torch_all(
            gamma=init_genotype,
            pi=init_frac,
            dtype=torch.float32,
            device=args.device,
        ),
        lag=args.lag,
        stop_at=args.stop_at,
        learning_rate=args.learning_rate,
        max_iter=args.max_iter,
        clip_norm=args.clip_norm,
    )
    # Report GPU memory usage, then release cached blocks.
    # NOTE(review): device index 0 is hard-coded here regardless of which
    # CUDA device was named in --device.
    if args.device.startswith("cuda"):
        info(
            "CUDA available mem: {}".format(
                torch.cuda.get_device_properties(0).total_memory
            ),
        )
        info("CUDA reserved mem: {}".format(torch.cuda.memory_reserved(0)))
        info("CUDA allocated mem: {}".format(torch.cuda.memory_allocated(0)))
        info(
            "CUDA free mem: {}".format(
                torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0)
            )
        )
        torch.cuda.empty_cache()

    # Assemble estimates and observed counts into one labelled dataset.
    # NOTE(review): the estimates are read from `mapest3`, but the only fit
    # result bound above is `mapest1`; `s_collapse` is also unbound.
    info("Finished fitting model.")
    result = xr.Dataset(
        {
            "gamma": (["strain", "position"], mapest3["gamma"]),
            "rho": (["strain"], mapest3["rho"]),
            "alpha_hyper": ([], mapest3["alpha_hyper"]),
            "pi": (["library_id", "strain"], mapest3["pi"]),
            "epsilon": (["library_id"], mapest3["epsilon"]),
            "rho_hyper": ([], mapest3["rho_hyper"]),
            "epsilon_hyper": ([], mapest3["epsilon_hyper"]),
            "pi_hyper": ([], mapest3["pi_hyper"]),
            "alpha": (["library_id"], mapest3["alpha"]),
            "p_noerr": (["library_id", "position"], mapest3["p_noerr"]),
            "p": (["library_id", "position"], mapest3["p"]),
            "y": (["library_id", "position"], y_obs_ss),
            "m": (["library_id", "position"], m_ss),
            "elbo_trace": (["iteration"], history1),
        },
        coords=dict(
            strain=np.arange(s_collapse),
            position=data_fit.position,
            library_id=data_fit.library_id,
        ),
    )

    # Writing is skipped when --outpath is not given. The listed variables
    # are stored compressed (zlib level 6) as float32, or uint16 for counts.
    if args.outpath:
        info("Saving results.")
        result.to_netcdf(
            args.outpath,
            encoding=dict(
                gamma=dict(dtype="float32", zlib=True, complevel=6),
                pi=dict(dtype="float32", zlib=True, complevel=6),
                p_noerr=dict(dtype="float32", zlib=True, complevel=6),
                p=dict(dtype="float32", zlib=True, complevel=6),
                y=dict(dtype="uint16", zlib=True, complevel=6),
                m=dict(dtype="uint16", zlib=True, complevel=6),
                elbo_trace=dict(dtype="float32", zlib=True, complevel=6),
            ),
        )


## Model Specification

In [None]:
# Histogram 10,000 draws from Beta(1.5, 1.5 / 0.01), the same values that the
# simulation and fitting cells pass as epsilon_hyper_alpha / epsilon_hyper_beta.
epsilon_hyper_alpha, epsilon_hyper_beta = 1.5, 1.5 / 0.01
plt.hist(pyro.sample('epsilon_hyper', dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta).expand([10000])).cpu().numpy(), bins=100)
# A trailing None keeps the notebook from echoing plt.hist's return value.
None

In [None]:
# Histogram 1,000 draws from sf.model.NegativeBinomialReparam with first
# argument 10, r=1 and eps=1e-5 (presumably a mean/dispersion form; confirm).
plt.hist(pyro.sample('test', sf.model.NegativeBinomialReparam(torch.tensor(10.), r=torch.tensor(1.), eps=torch.tensor(1e-5)).expand([1000])).numpy())

In [None]:
# Trace the model once with n=100, g=200, s=20 and log the formatted shapes
# of the resulting trace.
sf.pyro_util.shape_info(sf.model.model, n=100, g=200, s=20)

## Simulation

### SimShape-1: Small study

In [None]:
# Seed the RNG so the simulated dataset is reproducible.
seed = 1
pyro.util.set_rng_seed(seed)

# Dimensions of the simulated study, passed below as n, g and s
# (presumably samples, positions and strains; confirm against the model).
n_sim = 100
g_sim = 5000
s_sim = 20

# Simulate from the model with alpha_hyper_mean fixed through `data` and the
# remaining hyperparameters supplied as keyword arguments.
# NOTE(review): the fitting cell passes alpha_hyper_hyper_mean /
# alpha_hyper_hyper_scale rather than alpha_hyper_mean; confirm that both
# spellings are accepted by the model.
sim1 = sf.model.simulate(
    sf.model.condition_model(
        sf.model.model,
        data=dict(
            alpha_hyper_mean=100.
        ),
        n=n_sim,
        g=g_sim,
        s=s_sim,
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_p=0.7,
        pi_hyper=0.5,
        rho_hyper=10.,
        mu_hyper_mean=2.,
        mu_hyper_scale=0.5,
        m_hyper_r=10.,
        alpha_hyper_scale=0.5,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5/0.01,
        device='cpu'
    )
)

## Visualization

In [None]:
# Upper bounds used to slice arrays before plotting in the cells below.
n_plt = 100
g_plt = 200
s_plt = 20

In [None]:
# Plot the simulated community composition for the first n_plt samples and
# s_plt strains. pi is laid out as (sample, strain), matching the slice used
# for sim1_pi_init below, so the sample bound goes on the first axis.
sf.plot.plot_community(sim1['pi'][:n_plt, :s_plt])

In [None]:
# Plot the output of counts_to_p_estimate on the y and m counts for the first
# n_plt rows and g_plt columns (presumably per-sample allele-frequency
# estimates; confirm), with a progress indicator for the linkage step.
sf.plot.plot_genotype(
    sf.genotype.counts_to_p_estimate(
        sim1['y'][:n_plt, :g_plt],
        sim1['m'][:n_plt, :g_plt]),
    linkage_kw=dict(progress=True)
)

In [None]:
# Same count-derived estimate as the previous cell, shown with
# plot_genotype_similarity instead of plot_genotype.
sf.plot.plot_genotype_similarity(sf.genotype.counts_to_p_estimate(sim1['y'][:n_plt, :g_plt], sim1['m'][:n_plt, :g_plt]), linkage_kw=dict(progress=True))

In [None]:
# Plot the simulated gamma for the first s_plt rows and g_plt columns
# (presumably strains x positions; confirm).
sf.plot.plot_genotype(sim1['gamma'][:s_plt, :g_plt])

In [None]:
# Plot the simulated delta for the first s_plt rows and g_plt columns.
sf.plot.plot_missing(sim1['delta'][:s_plt, :g_plt])

In [None]:
# Plot the simulated nu for the first n_plt rows and g_plt columns.
sf.plot.plot_missing(sim1['nu'][:n_plt, :g_plt])

In [None]:
# Clustered heatmap of the simulated m counts, colored on a symmetric-log
# scale that is linear within +/-1.
sns.clustermap(sim1['m'][:n_plt, :g_plt], norm=mpl.colors.SymLogNorm(linthresh=1))

In [None]:
# plot_genotype_similarity on the simulated gamma (first s_plt rows and
# g_plt columns).
sf.plot.plot_genotype_similarity(sim1['gamma'][:s_plt, :g_plt])

In [None]:
# Distribution of the simulated epsilon values.
plt.hist(sim1['epsilon'], bins=50)
# A trailing None keeps the notebook from echoing plt.hist's return value.
None

In [None]:
# Distribution of the simulated alpha values.
plt.hist(sim1['alpha'], bins=20)
# A trailing None keeps the notebook from echoing plt.hist's return value.
None

## Estimation

### Initialization

In [None]:
# Fit on the first 1000 positions only (the trailing comment records the
# all-positions alternative) and on every row of y.
g_fit = 1000  # sim1['y'].shape[1]
n_fit = sim1['y'].shape[0]

# Derive initial gamma and pi values (plus a third output kept as
# sim1_cdmat) from the y and m counts.
sim1_gamma_init, sim1_pi_init, sim1_cdmat = sf.estimation.initialize_parameters_by_clustering_samples(
    sim1['y'][:n_fit, :g_fit],
    sim1['m'][:n_fit, :g_fit],
    thresh=0.05,
    additional_strains_factor=0.,
    progress=True,
)

# Shape of the initial pi; the fitting cell reads the strain count from
# sim1_gamma_init.shape[0].
print(sim1_pi_init.shape)

In [None]:
# Plot the initial gamma for the first s_plt rows and g_plt columns.
sf.plot.plot_genotype(sim1_gamma_init[:s_plt, :g_plt])

In [None]:
# plot_genotype_similarity on the full initial gamma (not sliced, unlike the
# neighbouring cells).
sf.plot.plot_genotype_similarity(sim1_gamma_init)

In [None]:
# Plot the initial pi for the first n_plt rows and s_plt columns.
sf.plot.plot_community(sim1_pi_init[:n_plt, :s_plt])

### Fitting

In [None]:
# The number of strains to fit is the row count of the initial gamma; the
# clustering-based gamma and pi are handed over as starting values.
s_fit = sim1_gamma_init.shape[0]
initialize_params = dict(gamma=sim1_gamma_init, pi=sim1_pi_init)

# Fit the model to the first g_fit positions of the simulated counts.
# estimate_parameters returns the fitted values and a second object bound
# here as `history`.
sim1_fit1, history = sf.estimation.estimate_parameters(
    sf.model.model,
    data=dict(y=sim1['y'][:, :g_fit], m=sim1['m'][:, :g_fit]),
    n=n_fit,
    g=g_fit,
    s=s_fit,
    gamma_hyper=0.1,
    pi_hyper=1.0,
    rho_hyper=0.5,
    mu_hyper_mean=5,
    mu_hyper_scale=5.,
    m_hyper_r=10.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=10.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    initialize_params=initialize_params,
    device='cpu',
    lag=100,
    lr=1e-1,
)

### Merging Strains

In [None]:
# Merge fitted genotypes using thresh=0.1 (presumably a distance cutoff;
# confirm); returns merged gamma, pi and delta.
sim1_fit1_gamma_merge, sim1_fit1_pi_merge, sim1_fit1_delta_merge  = sf.estimation.merge_similar_genotypes(
    sim1_fit1['gamma'],
    sim1_fit1['pi'],
    delta=sim1_fit1['delta'],
    thresh=0.1,
)

# Row counts of gamma before and after merging.
# print(sim1_gamma_init.shape[0], sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])
print(sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])

## Evaluation

In [None]:
# Apply mask_missing_genotype to the true gamma/delta (restricted to the
# fitted positions) and to the fitted gamma/delta; both results are compared
# in the evaluation cells below.
sim1_gamma_adjusted = sf.genotype.mask_missing_genotype(sim1['gamma'][:, :g_fit], sim1['delta'][:, :g_fit])
sim1_fit1_gamma_adjusted = sf.genotype.mask_missing_genotype(sim1_fit1['gamma'], sim1_fit1['delta'])

### Ground Truth

#### Visualization

In [None]:
# Compare the masked true genotypes ("true") with the masked fitted
# genotypes ("adj") over the first g_plt positions; the commented entries are
# alternative series that are currently disabled.
sf.plot.plot_genotype_comparison(
    data=dict(
        true=sim1_gamma_adjusted[:, :g_plt],
#         fit=sim1_fit1['gamma'][:, :g_plt],
        adj=sim1_fit1_gamma_adjusted[:, :g_plt],
#         init=sim1_gamma_init,
#         merg=sim1_fit1_gamma_merge,
    ),
    linkage_kw=dict(progress=True),
)

In [None]:
# Compare the true pi with the fitted pi; the commented entries are
# alternative series that are currently disabled.
sf.plot.plot_community_comparison(
    data=dict(
        true=sim1['pi'],
        fit=sim1_fit1['pi'],
#         init=sim1_pi_init,
#         merg=sim1_fit1_pi_merge,
    ),
)

In [None]:
# True vs. fitted epsilon, with the identity line drawn from 0 to 0.04.
plt.scatter(sim1['epsilon'], sim1_fit1['epsilon'])
plt.plot([0, 0.04], [0, 0.04])

In [None]:
# True vs. fitted alpha, with the identity line drawn from 0 to 200.
plt.scatter(sim1['alpha'], sim1_fit1['alpha'])
plt.plot([0, 200], [0, 200])

In [None]:
# Heatmap of the fitted delta with the color scale pinned to [0, 1].
sns.heatmap(sim1_fit1['delta'], vmin=0, vmax=1)

In [None]:
# True vs. fitted mu, with the identity line drawn from 0 to 40.
plt.scatter(sim1['mu'], sim1_fit1['mu'])
plt.plot([0, 40], [0, 40])

In [None]:
# TODO: Plot comparing genotype accuracy to true strain abundance
# colored by mean entropy of the estimated genotype masked by delta

#### Fit scores

In [None]:
# Scatter fitted alpha against sample_mean_masked_genotype_entropy of the
# fitted pi/gamma/delta, then display the mean of that score.
# NOTE(review): sample_mean_masked_genotype_entropy is called unqualified,
# while the rest of the notebook reaches library code through `sf.`; confirm
# which sfacts module provides it or that an earlier cell binds the name.
plt.scatter(sim1_fit1['alpha'], sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']))
sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']).mean()

In [None]:
# Match the masked true genotypes against the masked fitted genotypes.
# NOTE(review): match_genotypes is called unqualified, while the rest of the
# notebook reaches library code through `sf.`; confirm where it is bound.
best_hit, best_dist = match_genotypes(sim1_gamma_adjusted[:, :g_fit], sim1_fit1_gamma_adjusted[:, :g_fit])

# Distances weighted by the column means of the true pi and summed.
print('weighted_mean_distance:', (best_dist * sim1['pi'].mean(0)).sum())
# x-axis: true pi scaled row-wise by mu and summed over rows, giving one
# value per column of pi (presumably per-strain total coverage; confirm).
plt.scatter((sim1['pi'] * sim1['mu'].reshape(-1, 1)).sum(0), best_dist)

In [None]:
# The notebook's import cell only has `import scipy as sp`, so bring pdist
# into scope here.
from scipy.spatial.distance import pdist

# Pairwise Bray-Curtis similarity (1 - dissimilarity) between the rows of the
# true and of the fitted pi, compared pair by pair.
bc_sim = 1 - pdist(sim1['pi'], metric='braycurtis')
bc_fit = 1 - pdist(sim1_fit1['pi'], metric='braycurtis')
plt.scatter(
    bc_sim,
    bc_fit,
    marker='.',
    alpha=0.2,
)

# NOTE(review): community_accuracy_test is called unqualified, while the rest
# of the notebook reaches library code through `sf.`; confirm where it is
# bound.
community_accuracy_test(sim1['pi'], sim1_fit1['pi'])

### No Ground Truth

#### Visualization

In [None]:
# Strains that are not representative of true haplotypes
# are high entropy (even after masking with delta)
# and have low estimated total coverage.

# Match each masked fitted genotype against the masked true genotypes (the
# reverse direction of the ground-truth "Fit scores" cell).
# NOTE(review): match_genotypes and mean_masked_genotype_entropy are called
# unqualified, while the rest of the notebook reaches library code through
# `sf.`; confirm where they are bound.
best_true_strain, best_true_strain_dist = match_genotypes(sim1_fit1_gamma_adjusted[:, :g_fit], sim1_gamma_adjusted[:, :g_fit])
# NOTE(review): this bare expression is not the last statement of the cell,
# so the notebook does not display it.
best_true_strain_dist

# x: fitted pi scaled row-wise by fitted mu and summed over rows; y: distance
# to the best-matching true genotype; color: mean_masked_genotype_entropy of
# the fitted gamma/delta.
plt.scatter((sim1_fit1['pi'] * sim1_fit1['mu'].reshape((-1, 1))).sum(0), best_true_strain_dist, c=mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']))

In [None]:
# Plot the masked fitted genotypes over the first g_plt positions. The
# function is reached through `sf.plot`, as in the other plotting cells,
# because the notebook only imports `sfacts as sf`.
sf.plot.plot_genotype(sim1_fit1_gamma_adjusted[:, :g_plt], linkage_kw=dict(progress=True))

In [None]:
# Plot the fitted community composition. The function is reached through
# `sf.plot`, as in the other plotting cells, because the notebook only
# imports `sfacts as sf`.
sf.plot.plot_community(sim1_fit1['pi'])

#### Confidence Scores

In [None]:
# Distribution of mean_masked_genotype_entropy over the fitted gamma/delta,
# restricted to the first g_plt positions.
# NOTE(review): mean_masked_genotype_entropy is called unqualified, while the
# rest of the notebook reaches library code through `sf.`; confirm where it
# is bound.
plt.hist(mean_masked_genotype_entropy(sim1_fit1['gamma'][:, :g_plt], sim1_fit1['delta'][:, :g_plt]))
# A trailing None keeps the notebook from echoing plt.hist's return value.
None

In [None]:
# Distribution of the fitted alpha values.
plt.hist(sim1_fit1['alpha'], bins=20)
# A trailing None keeps the notebook from echoing plt.hist's return value.
None

In [None]:
# Plot only the fitted genotypes whose mean masked entropy is below 0.1. The
# plotting function is reached through `sf.plot`, as in the other plotting
# cells, because the notebook only imports `sfacts as sf`.
# NOTE(review): mean_masked_genotype_entropy is still called unqualified;
# confirm which sfacts module provides it or that an earlier cell binds it.
sf.plot.plot_genotype(sim1_fit1['gamma'][mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']) < 0.1, :g_fit])