In [None]:
%load_ext autoreload
%autoreload 2

## Library

In [None]:
%%writefile sfacts/__init__.py

from sfacts.logging_util import *
from sfacts.pyro_util import *
from sfacts.metagenotype_model import *
from sfacts.genotype import *
from sfacts.plot import *
from sfacts.estimation import *
from sfacts.evaluation import *
from sfacts.workflow import *

# Alias: the same function is also exposed under the name `gamma_to_sign`.
gamma_to_sign = prob_to_sign

### logging_util

In [None]:
%%writefile sfacts/logging_util.py

from datetime import datetime
import sys


def info(*msg):
    """Write a timestamped message to standard error.

    The current local time is rendered in square brackets, followed by every
    positional argument (space separated, as with ``print``). The stream is
    flushed immediately.
    """
    stamp = f'[{datetime.now()}]'
    print(stamp, *msg, file=sys.stderr, flush=True)

### pyro_util

In [None]:
%%writefile sfacts/pyro_util.py

import pyro
import pyro.distributions as dist
import torch
from sfacts.logging_util import info


def as_torch(x, dtype=None, device=None):
    """Return a new tensor holding a copy of ``x`` with the given dtype/device.

    Parameters
    ----------
    x : scalar, sequence, array or torch.Tensor
        Value to convert.
    dtype, device : optional
        Passed through to torch; ``None`` keeps torch's default inference.

    Returns
    -------
    torch.Tensor
        A copy of the data (never a view of ``x``), detached from any graph.
    """
    if torch.is_tensor(x):
        # torch.tensor(<tensor>) also copies but emits a UserWarning; the
        # clone().detach() form is the equivalent that torch recommends.
        return x.clone().detach().to(dtype=dtype, device=device)
    return torch.tensor(x, dtype=dtype, device=device)

def all_torch(dtype=None, device=None, **kwargs):
    """Convert every keyword argument with ``as_torch``.

    Returns a dict with the same keys as ``kwargs`` whose values have been
    passed through ``as_torch`` using the shared ``dtype`` and ``device``.
    """
    converted = {}
    for name, value in kwargs.items():
        converted[name] = as_torch(value, dtype=dtype, device=device)
    return converted

def shape_info(model, *args, **kwargs):
    """Log the shapes of all sites recorded when running ``model``.

    The model is executed once under ``pyro.poutine.trace`` with the given
    arguments, log-probabilities are computed on the resulting trace, and the
    trace's formatted shape table is emitted through ``info``.
    """
    tracer = pyro.poutine.trace(model)
    model_trace = tracer.get_trace(*args, **kwargs)
    model_trace.compute_log_prob()
    shapes_table = model_trace.format_shapes()
    info(shapes_table)

### metagenotype_model

In [None]:
%%writefile sfacts/metagenotype_model.py

from sfacts.pyro_util import as_torch, all_torch
import pyro
import pyro.distributions as dist
import torch
from functools import partial


def NegativeBinomialReparam(mu, r, eps=1e-5):
    """Build a ``dist.NegativeBinomial`` from ``mu`` and ``r``.

    The success probability is computed as ``1 / (r / mu + 1)`` and clamped to
    the closed interval ``[eps, 1 - eps]`` before being handed to the
    distribution together with ``total_count=r``.
    """
    unclamped = 1. / ((r / mu) + 1.)
    probs = torch.clamp(unclamped, min=eps, max=1. - eps)
    return dist.NegativeBinomial(total_count=r, probs=probs)

def model(
    n,
    g,
    s,
    gamma_hyper=1.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    rho_hyper=1.,
    pi_hyper=1.,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=1.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    mu_hyper_mean=1.,
    mu_hyper_scale=1.,
    m_hyper_r=1.,
    dtype=torch.float32,
    device='cpu',
):
    """Pyro model for metagenotype counts.

    Parameters
    ----------
    n, g, s : int
        Sizes of the 'sample', 'position' and 'strain' plates respectively.
    gamma_hyper : temperature of the RelaxedBernoulli prior on ``gamma``.
    delta_hyper_temp, delta_hyper_p : temperature and probability of the
        RelaxedBernoulli prior on ``delta``.
    rho_hyper : concentration multiplier of the Dirichlet prior on ``rho``.
    pi_hyper : temperature of the RelaxedOneHotCategorical prior on ``pi``.
    alpha_hyper_hyper_mean, alpha_hyper_hyper_scale : LogNormal prior on
        ``alpha_hyper_mean`` (the mean is log-transformed for ``loc``).
    alpha_hyper_scale : converted to a tensor but then overwritten by the
        sample site of the same name (see NOTE below).
    epsilon_hyper_alpha, epsilon_hyper_beta : Beta prior on ``epsilon``.
    mu_hyper_mean, mu_hyper_scale : LogNormal prior on ``mu``.
    m_hyper_r : ``r`` passed to ``NegativeBinomialReparam`` for ``m``.
    dtype, device : used for every tensor created in the model.

    Sample sites: gamma, delta, rho, alpha_hyper_mean, alpha_hyper_scale, pi,
    mu, epsilon, alpha, m, y. Deterministic sites: nu, p_noerr, p.

    The function has no return statement; results are exposed only through
    the Pyro sites.
    """
    
    # Convert every scalar hyperparameter to a tensor with the requested
    # dtype/device. The tuple unpacking relies on both lists being in the
    # same order.
    (
        gamma_hyper,
        delta_hyper_temp,
        delta_hyper_p,
        rho_hyper,
        pi_hyper,
        alpha_hyper_hyper_mean,
        alpha_hyper_hyper_scale,
        alpha_hyper_scale,
        epsilon_hyper_alpha,
        epsilon_hyper_beta,
        mu_hyper_mean,
        mu_hyper_scale,
        m_hyper_r,
    ) = (
        as_torch(x, dtype=dtype, device=device)
        for x in [
            gamma_hyper,
            delta_hyper_temp,
            delta_hyper_p,
            rho_hyper,
            pi_hyper,
            alpha_hyper_hyper_mean,
            alpha_hyper_hyper_scale,
            alpha_hyper_scale,
            epsilon_hyper_alpha,
            epsilon_hyper_beta,
            mu_hyper_mean,
            mu_hyper_scale,
            m_hyper_r,
        ]
    )

    # Genotypes
    # 'position' is the last batch dim and 'strain' the one before it, so
    # gamma and delta are laid out as (strain, position).
    with pyro.plate('position', g, dim=-1):
        with pyro.plate('strain', s, dim=-2):
            # logits of 0 give a symmetric prior between the two alleles
            gamma = pyro.sample(
                'gamma', dist.RelaxedBernoulli(temperature=gamma_hyper, logits=torch.zeros((1,), dtype=dtype, device=device).squeeze())
            )
            # Position presence/absence
            delta = pyro.sample(
                'delta', dist.RelaxedBernoulli(temperature=delta_hyper_temp, probs=delta_hyper_p)
            )
    
    # Meta-community composition
    rho = pyro.sample('rho', dist.Dirichlet(rho_hyper * torch.ones(s, dtype=dtype, device=device)))

    alpha_hyper_mean = pyro.sample('alpha_hyper_mean', dist.LogNormal(loc=torch.log(alpha_hyper_hyper_mean), scale=alpha_hyper_hyper_scale))
    # NOTE(review): this rebinding discards the `alpha_hyper_scale` argument;
    # the value used below always comes from LogNormal(0, 1). Confirm that
    # ignoring the argument is intended.
    alpha_hyper_scale = pyro.sample('alpha_hyper_scale', dist.LogNormal(loc=0, scale=1))
    with pyro.plate('sample', n, dim=-1):
        # Community composition
        pi = pyro.sample('pi', dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho))
        # Sample coverage
        mu = pyro.sample('mu', dist.LogNormal(loc=torch.log(mu_hyper_mean), scale=mu_hyper_scale))
        # Sequencing error
        # (trailing axis added so the per-sample value broadcasts over positions)
        epsilon = pyro.sample('epsilon', dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta)).unsqueeze(-1)
        # Per-sample concentration used by the BetaBinomial observation below
        alpha = pyro.sample('alpha', dist.LogNormal(loc=torch.log(alpha_hyper_mean), scale=alpha_hyper_scale)).unsqueeze(-1)
        
    # Depth at each position
    # nu mixes the per-strain presence values by community composition.
    nu = pyro.deterministic("nu", pi @ delta)
    # Total counts: mean is nu scaled by the per-sample coverage mu.
    m = pyro.sample('m', NegativeBinomialReparam(nu * mu.reshape((-1,1)), m_hyper_r).to_event())

    # Expected fractions of each allele at each position
    # (only strains present at a position contribute; normalised by nu)
    p_noerr = pyro.deterministic('p_noerr', pi @ (gamma * delta) / nu)
    # Blend p_noerr with its complement using weight epsilon / 2.
    p = pyro.deterministic('p',
        (1 - epsilon / 2) * (p_noerr) +
        (epsilon / 2) * (1 - p_noerr)
    )
    
    # Observation
    y = pyro.sample(
        'y',
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m
        ).to_event(),
    )
    
def condition_model(model, data=None, device='cpu', dtype=torch.float32, **model_kwargs):
    """Condition ``model`` on ``data`` and bind its remaining arguments.

    ``data`` (a mapping of site name to value, or ``None`` for no
    conditioning) is converted with ``all_torch`` and applied through
    ``pyro.condition``. The result is wrapped in ``functools.partial`` so that
    ``model_kwargs``, ``dtype`` and ``device`` are pre-filled.
    """
    observed = {} if data is None else data
    observed_tensors = all_torch(**observed, dtype=dtype, device=device)
    conditioned = pyro.condition(model, data=observed_tensors)
    return partial(conditioned, **model_kwargs, dtype=dtype, device=device)
    
def simulate(model):
    """Draw a single sample from ``model`` and return it as numpy arrays.

    Every site returned by ``pyro.infer.Predictive`` (with ``num_samples=1``)
    is detached, moved to the CPU, converted to numpy and squeezed.
    """
    predictive = pyro.infer.Predictive(model, num_samples=1)
    draws = predictive()
    result = {}
    for site in draws.keys():
        result[site] = draws[site].detach().cpu().numpy().squeeze()
    return result

### genotype

In [None]:
%%writefile sfacts/genotype.py

import numpy as np
from scipy.spatial.distance import pdist, squareform
import scipy as sp
from scipy.cluster.hierarchy import linkage
from tqdm import tqdm

def prob_to_sign(gamma):
    """Linearly rescale values from the interval [0, 1] to [-1, 1]."""
    doubled = gamma * 2
    return doubled - 1

def genotype_distance(x, y):
    """Weighted distance between two genotypes given as values in [0, 1].

    Both inputs are rescaled with ``prob_to_sign`` and then compared with
    ``sign_genotype_distance``, which performs the same weighted-mean
    computation on the signed values.
    """
    return sign_genotype_distance(prob_to_sign(x), prob_to_sign(y))

def sign_genotype_distance(x, y):
    """Weighted distance between two genotypes on the [-1, 1] scale.

    Per element, the distance is the squared half-difference and the weight is
    the squared product of the two values, so elements where either input is
    near zero contribute little. Returns the weighted mean of the distances.
    """
    sq_half_diff = ((x - y) / 2) ** 2
    confidence = (x * y) ** 2
    numerator = (confidence * sq_half_diff).mean()
    denominator = confidence.mean()
    return numerator / denominator

def genotype_pdist(gamma, progress=False):
    """Condensed pairwise distance vector between the rows of ``gamma``.

    ``gamma`` is rescaled to [-1, 1] and every pair of rows (i < j, in
    row-major order) is compared with ``sign_genotype_distance``. A tqdm
    progress bar is shown when ``progress`` is true.

    Returns a 1-D array of length ``m * (m - 1) // 2`` for ``m`` rows.
    """
    signed = np.asarray(gamma * 2 - 1)
    n_rows = signed.shape[0]
    condensed = np.empty((n_rows * (n_rows - 1)) // 2)
    row_pairs = (
        (first, second)
        for first in range(0, n_rows - 1)
        for second in range(first + 1, n_rows)
    )
    with tqdm(total=len(condensed), disable=(not progress)) as pbar:
        for slot, (first, second) in enumerate(row_pairs):
            condensed[slot] = sign_genotype_distance(signed[first], signed[second])
            pbar.update()
    return condensed
    

def counts_to_p_estimate(y, m, pseudo=1):
    """Estimate a fraction from counts ``y`` out of totals ``m``.

    ``pseudo`` is added to the numerator and twice ``pseudo`` to the
    denominator.
    """
    numerator = y + pseudo
    denominator = m + pseudo * 2
    return numerator / denominator

def genotype_linkage(gamma, progress=False, **kwargs):
    """Hierarchical clustering linkage of the rows of ``gamma``.

    Distances come from ``genotype_pdist``. ``kwargs`` are forwarded to
    ``scipy.cluster.hierarchy.linkage``; ``method`` defaults to 'complete'
    unless overridden.

    Returns ``(linkage_matrix, condensed_distances)``.
    """
    condensed = genotype_pdist(gamma, progress=progress)
    linkage_options = {'method': 'complete', **kwargs}
    return linkage(condensed, **linkage_options), condensed

def mask_missing_genotype(gamma, delta):
    """Scale ``gamma`` by ``delta`` on the logit scale.

    The logit of ``gamma`` is multiplied by ``delta`` and mapped back with the
    logistic function, so ``delta == 1`` leaves ``gamma`` unchanged and
    ``delta == 0`` yields 0.5.
    """
    logit_gamma = sp.special.logit(gamma)
    scaled_logit = logit_gamma * delta
    return sp.special.expit(scaled_logit)

### plot

In [None]:
%%writefile sfacts/plot.py

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from sfacts.genotype import genotype_linkage, prob_to_sign
from scipy.spatial.distance import squareform
import pandas as pd

def calculate_clustermap_dims(nx, ny, scalex=0.15, scaley=0.02, dwidth=0.2, dheight=1.0):
    """Compute figure size and dendrogram ratio for a clustermap.

    The heatmap takes ``nx * scalex`` by ``ny * scaley``; ``dwidth`` and
    ``dheight`` are added for the dendrograms.

    Returns ``(figure_width, figure_height, (dwidth / figure_width,
    dheight / figure_height))``.
    """
    fwidth = nx * scalex + dwidth
    fheight = ny * scaley + dheight
    return fwidth, fheight, (dwidth / fwidth, dheight / fheight)
    

def plot_genotype(gamma, linkage_kw=None, **kwargs):
    """Draw a clustermap of genotypes rescaled to [-1, 1].

    ``gamma`` is transposed for display, so its rows become heatmap columns;
    the columns are ordered by ``genotype_linkage`` (configured through
    ``linkage_kw``). ``kwargs`` override the default ``sns.clustermap``
    options.
    """
    if linkage_kw is None:
        linkage_kw = {}
    col_linkage, _ = genotype_linkage(gamma, **linkage_kw)

    transposed = gamma.T
    n_rows, n_cols = transposed.shape
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_cols, n_rows, scalex=0.15, scaley=0.02, dwidth=0.2, dheight=1.0
    )

    clustermap_kw = {
        'vmin': -1,
        'vmax': 1,
        'cmap': 'coolwarm',
        'dendrogram_ratio': dendrogram_ratio,
        'col_linkage': col_linkage,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        'yticklabels': 0,
        **kwargs,
    }
    sns.clustermap(prob_to_sign(transposed), **clustermap_kw)
    
def plot_missing(delta, **kwargs):
    """Draw a clustermap of ``delta`` (transposed) on a 0-1 colour scale.

    ``kwargs`` override the default ``sns.clustermap`` options.
    """
    transposed = delta.T
    n_rows, n_cols = transposed.shape
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_cols, n_rows, scalex=0.15, scaley=0.02, dwidth=0.2, dheight=1.0
    )

    clustermap_kw = {
        'vmin': 0,
        'vmax': 1,
        'dendrogram_ratio': dendrogram_ratio,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        'yticklabels': 0,
        **kwargs,
    }
    sns.clustermap(transposed, **clustermap_kw)
    
def plot_community(pi, **kwargs):
    """Draw a clustermap of ``pi`` on a 0-1 colour scale.

    Clustering uses the cosine metric by default; ``kwargs`` override the
    default ``sns.clustermap`` options.
    """
    n_rows, n_cols = pi.shape
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_cols, n_rows, scalex=0.2, scaley=0.1, dwidth=0.2, dheight=1.0
    )

    clustermap_kw = {
        'metric': 'cosine',
        'vmin': 0,
        'vmax': 1,
        'dendrogram_ratio': dendrogram_ratio,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        **kwargs,
    }
    sns.clustermap(pi, **clustermap_kw)
    
def plot_genotype_similarity(gamma, linkage_kw=None, **kwargs):
    """Draw a clustermap of pairwise genotype similarity (1 - distance).

    Distances and the linkage come from ``genotype_linkage`` (configured
    through ``linkage_kw``); the same linkage orders rows and columns.
    ``kwargs`` override the default ``sns.clustermap`` options.
    """
    if linkage_kw is None:
        linkage_kw = {}
    shared_linkage, condensed = genotype_linkage(gamma, **linkage_kw)
    square_dist = squareform(condensed)

    n_genotypes = gamma.shape[0]
    fwidth, fheight, dendrogram_ratio = calculate_clustermap_dims(
        n_genotypes, n_genotypes, scalex=0.15, scaley=0.15, dwidth=0.5, dheight=0.5
    )

    clustermap_kw = {
        'vmin': 0,
        'vmax': 1,
        'dendrogram_ratio': dendrogram_ratio,
        'row_linkage': shared_linkage,
        'col_linkage': shared_linkage,
        'figsize': (fwidth, fheight),
        'xticklabels': 1,
        'yticklabels': 1,
        **kwargs,
    }
    sns.clustermap(1 - square_dist, **clustermap_kw)
    
def plot_genotype_comparison(data=None, **kwargs):
    """Plot several genotype matrices together with ``plot_genotype``.

    ``data`` maps a label to a 2-D array. The arrays are stacked row-wise and
    each row is named ``'<label>_<row index>'``. ``kwargs`` are forwarded to
    ``plot_genotype`` (``xticklabels`` defaults to 1).
    """
    labelled = []
    for label in data:
        values = data[label]
        row_names = [f'{label}_{i}' for i in range(values.shape[0])]
        labelled.append(pd.DataFrame(values, index=row_names))
    stacked = pd.concat(labelled)
    plot_kw = {'xticklabels': 1, **kwargs}
    plot_genotype(stacked, **plot_kw)

def plot_community_comparison(data=None, **kwargs):
    """Plot several community matrices side by side with ``plot_community``.

    ``data`` maps a label to a 2-D array. The arrays are joined column-wise
    and each column is named ``'<label>_<column index>'``. ``kwargs`` are
    forwarded to ``plot_community`` (``xticklabels`` defaults to 1).
    """
    labelled = []
    for label in data:
        values = data[label]
        column_names = [f'{label}_{i}' for i in range(values.shape[1])]
        labelled.append(pd.DataFrame(values, columns=column_names))
    stacked = pd.concat(labelled, axis=1)
    plot_kw = {'xticklabels': 1, **kwargs}
    plot_community(stacked, **plot_kw)

def plot_loss_history(trace):
    """Plot a loss trace, shifted so its minimum is zero, on a log y-axis.

    Parameters
    ----------
    trace : sequence of float
        Loss value per optimisation step.
    """
    # Imported locally: this module has no top-level numpy import, so the
    # previous bare `np` reference raised NameError.
    import numpy as np

    trace = np.array(trace)
    plt.plot((trace - trace.min()))
    plt.yscale('log')
    plt.yscale('log')

### estimation

In [None]:
%%writefile sfacts/estimation.py

from sklearn.cluster import AgglomerativeClustering
from sfacts.genotype import genotype_pdist, mask_missing_genotype
from sfacts.pyro_util import all_torch

import pandas as pd
import numpy as np
import scipy as sp
from scipy.spatial.distance import squareform

import pyro
import pyro.distributions as dist
import torch

from tqdm import tqdm
from sfacts.logging_util import info

from sfacts.metagenotype_model import condition_model

def cluster_genotypes(
    gamma, thresh, progress=False
):
    """Cluster the rows of ``gamma`` by complete-linkage agglomeration.

    Pairwise distances come from ``genotype_pdist``; clusters are cut at
    ``distance_threshold=thresh`` rather than at a fixed cluster count.

    Parameters
    ----------
    gamma : 2-D array with values in [0, 1]; one genotype per row.
    thresh : float
        Linkage distance at or above which clusters are not merged.
    progress : bool
        Show a progress bar while computing distances.

    Returns
    -------
    pandas.Series
        Cluster label for each row of ``gamma`` (default integer index).
    """
    clustering_kw = dict(
        n_clusters=None,
        linkage="complete",
        distance_threshold=thresh,
    )
    try:
        # scikit-learn >= 1.2 calls the precomputed-distance option `metric`;
        # the older `affinity` spelling was deprecated and later removed.
        clusterer = AgglomerativeClustering(metric="precomputed", **clustering_kw)
    except TypeError:
        # Older releases only accept the original keyword.
        clusterer = AgglomerativeClustering(affinity="precomputed", **clustering_kw)

    distance_matrix = squareform(genotype_pdist(gamma, progress=progress))
    return pd.Series(clusterer.fit(distance_matrix).labels_)

def initialize_parameters_by_clustering_samples(
    y, m, thresh, additional_strains_factor=0.5, progress=False,
):
    """Build initial ``gamma`` and ``pi`` values by clustering samples.

    Each sample's allele fraction is estimated as ``(y + 1) / (m + 2)`` and
    the samples are clustered with ``cluster_genotypes``. Counts are then
    pooled within each cluster to give one initial genotype per cluster.

    Parameters
    ----------
    y, m : 2-D arrays of shape (n, g)
        Allele counts and total counts.
    thresh : float
        Distance threshold forwarded to ``cluster_genotypes``.
    additional_strains_factor : float
        Extra genotypes (all 0.5) appended, as a fraction of the number of
        clusters, truncated to an integer.
    progress : bool
        Show a progress bar during distance computation.

    Returns
    -------
    gamma_init : array of shape (s_init, g)
    pi_init : array of shape (n, s_init); rows sum to 1, with the sample's
        own cluster receiving the largest weight.
    """
    n, g = y.shape

    sample_genotype = (y + 1) / (m + 2)
    clust = cluster_genotypes(sample_genotype, thresh=thresh, progress=progress)

    # Pool counts over the samples assigned to each cluster.
    y_total = pd.DataFrame(y).groupby(clust).sum().values
    m_total = pd.DataFrame(m).groupby(clust).sum().values
    clust_genotype = (y_total + 1) / (m_total + 2)
    additional_haplotypes = int(
        additional_strains_factor * clust_genotype.shape[0]
    )

    # Extra strains start at 0.5 everywhere.
    gamma_init = np.concatenate(
        [clust_genotype, np.full((additional_haplotypes, g), 0.5)]
    )

    s_init = gamma_init.shape[0]
    pi_init = np.ones((n, s_init))
    # The sample's own cluster gets weight s_init - 1 against 1 for every
    # other strain. The floor of 1 covers s_init == 1, where the weight would
    # otherwise be 0 and the normalisation below would produce 0 / 0 = NaN.
    own_cluster_weight = max(s_init - 1, 1)
    pi_init[np.arange(n), clust.values] = own_cluster_weight
    pi_init /= pi_init.sum(1, keepdims=True)

    assert (~np.isnan(gamma_init)).all()
    assert (~np.isnan(pi_init)).all()

    return gamma_init, pi_init


def estimate_parameters(
    model,
    data,
    dtype=torch.float32,
    device='cpu',
    initialize_params=None,
    maxiter=10000,
    lag=100,
    lr=1e-0,
    clip_norm=100,
    progress=True,
    **model_kwargs,
):
    """Fit ``model`` to ``data`` with SVI and return point estimates.

    Parameters
    ----------
    model : callable
        Pyro model; conditioned on ``data`` through ``condition_model``.
    data : dict
        Observed site values, forwarded to ``condition_model``.
    dtype, device :
        Forwarded to ``condition_model`` and used to convert
        ``initialize_params``.
    initialize_params : dict or None
        Initial values by site name, handed to ``init_to_value``.
    maxiter : int
        Maximum number of SVI steps.
    lag : int
        Window (in steps) of the convergence check.
    lr, clip_norm :
        Learning rate and gradient-norm clipping for the Adamax optimiser.
    progress : bool
        Show a tqdm progress bar and log "Converged".
    **model_kwargs :
        Extra arguments bound to the model by ``condition_model``.

    Returns
    -------
    est : dict of numpy arrays, one per site returned by ``Predictive``.
    history : list of the values returned by ``svi.step()``.

    Raises
    ------
    RuntimeError
        If ``svi.step()`` returns NaN. The ``finally`` block still runs
        ``Predictive`` before the error propagates.
    """
    conditioned_model = condition_model(
        model,
        data=data,
        dtype=dtype,
        device=device,
        **model_kwargs,
    )
    if initialize_params is None:
        initialize_params = {}

    _guide = pyro.infer.autoguide.AutoLaplaceApproximation(
        conditioned_model,
        init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
            values=all_torch(**initialize_params, dtype=dtype, device=device)
        ),
    )
    opt = pyro.optim.Adamax({"lr": lr}, {"clip_norm": clip_norm})
    svi = pyro.infer.SVI(
        conditioned_model,
        _guide,
        opt,
        loss=pyro.infer.JitTrace_ELBO()
    )
    # Drop parameters left in the global store by any earlier fit.
    pyro.clear_param_store()

    history = []
    pbar = tqdm(range(maxiter), disable=(not progress))
    try:
        for i in pbar:
            # NOTE(review): stored as `elbo`, but the convergence test below
            # treats a decreasing value as improvement, i.e. as a loss.
            elbo = svi.step()

            if np.isnan(elbo):
                raise RuntimeError("ELBO NaN?")

            # Fit tracking
            history.append(elbo)

            # Reporting/Breaking
            # Only checked every 10th step, and only once more than `lag`
            # steps have been recorded.
            if (i % 10 == 0):
                if i > lag:
                    # Change over the last step, and average change per step
                    # measured against the value `lag` entries back.
                    delta = history[-2] - history[-1]
                    delta_lag = (history[-lag] - history[-1]) / lag
                    # No net decrease over the window counts as converged.
                    if delta_lag <= 0:
                        if progress:
                            info("Converged")
                        break
                    pbar.set_postfix({
                        'ELBO': history[-1],
                        'delta': delta,
                        f'lag{lag}': delta_lag,
                    })
    except KeyboardInterrupt:
        # A manual interrupt ends optimisation early; whatever was reached is
        # still returned.
        info("Interrupted")
        pass
    finally:         
        # Runs on every exit path: normal completion, convergence, interrupt,
        # and also while another exception is propagating.
        est = pyro.infer.Predictive(conditioned_model, guide=_guide, num_samples=1)()
        # Average over the leading axis and drop length-1 axes.
        est = {
            k: est[k].detach().cpu().numpy().mean(0).squeeze()
            for k in est.keys()
        }
    return est, history


def merge_similar_genotypes(
    gamma, pi, thresh, delta=None,
):
    """Merge strains whose masked genotypes cluster together.

    Genotypes are masked with ``mask_missing_genotype`` and clustered with
    ``cluster_genotypes``. Within each cluster the masked genotypes and
    ``delta`` are averaged, and the matching columns of ``pi`` are summed.

    Parameters
    ----------
    gamma : array of shape (s, g)
    pi : array of shape (n, s)
    thresh : float
        Distance threshold forwarded to ``cluster_genotypes``.
    delta : array of shape (s, g) or None
        Defaults to all ones (no masking).

    Returns
    -------
    gamma_mean : array (clusters, g)
    pi_sum : array (n, clusters)
    delta_mean : array (clusters, g)
    """
    if delta is None:
        delta = np.ones_like(gamma)

    gamma_adjust = mask_missing_genotype(gamma, delta)
    clust = cluster_genotypes(gamma_adjust, thresh=thresh)

    # NOTE(review): expit(logit(x)) returns x, so this is an average on the
    # probability scale. If a logit-scale average was intended, the mean
    # belongs inside expit(...). Kept as is to preserve results.
    gamma_mean = (
        pd.DataFrame(gamma_adjust)
        .groupby(clust)
        .apply(lambda x: sp.special.expit(sp.special.logit(x)).mean(0))
        .values
    )
    delta_mean = pd.DataFrame(delta).groupby(clust).mean().values
    # Group strain columns by transposing, grouping rows and transposing
    # back; `groupby(..., axis='columns')` is deprecated in pandas.
    pi_sum = pd.DataFrame(pi).T.groupby(clust).sum().T.values

    return gamma_mean, pi_sum, delta_mean

### evaluation

In [None]:
%%writefile sfacts/evaluation.py
from scipy.spatial.distance import cdist, pdist
import pandas as pd
import numpy as np


def binary_entropy(p):
    """Element-wise entropy, in bits, of a Bernoulli variable with parameter ``p``.

    The limiting value 0 is returned at ``p == 0`` and ``p == 1``, where the
    direct ``p * log2(p)`` form evaluates to NaN.

    Parameters
    ----------
    p : scalar or array-like with values in [0, 1]

    Returns
    -------
    Entropy with the same shape as ``p``.
    """
    # Local import: xlogy(x, x) is x * log(x), defined as 0 at x == 0.
    from scipy.special import xlogy

    q = 1 - p
    # xlogy uses the natural log; divide by ln(2) to express the result in bits.
    ent = -(xlogy(p, p) + xlogy(q, q)) / np.log(2)
    return ent

def sum_binary_entropy(p, normalize=False, axis=None):
    """Sum of element-wise binary entropies (in bits) of ``p``.

    Parameters
    ----------
    p : array with values in [0, 1]
    normalize : bool
        Divide the sum by the number of elements summed, giving a mean.
    axis : None, int or tuple of int
        Axis or axes to sum over; ``None`` sums over everything.

    Returns
    -------
    Scalar, or array reduced along ``axis``. Entries equal to 0 or 1
    contribute 0 rather than NaN.
    """
    # Local import: xlogy(x, x) is x * log(x), defined as 0 at x == 0.
    from scipy.special import xlogy

    q = 1 - p
    ent = np.sum(-(xlogy(p, p) + xlogy(q, q)) / np.log(2), axis=axis)
    if normalize:
        # Elements summed per output entry. Unlike `p.shape[axis]`, this is
        # also valid for axis=None and for a tuple of axes.
        n_terms = np.size(p) // max(np.size(ent), 1)
        ent = ent / n_terms
    return ent

def mean_masked_genotype_entropy(gamma, delta):
    """Row-wise mean of the binary entropy of ``gamma`` weighted by ``delta``."""
    masked_entropy = binary_entropy(gamma) * delta
    return masked_entropy.mean(1)

def sample_mean_masked_genotype_entropy(pi, gamma, delta):
    """Matrix product of ``pi`` with the per-row masked genotype entropy.

    The entropies from ``mean_masked_genotype_entropy`` are reshaped into a
    column vector, left-multiplied by ``pi``, and the result is squeezed.
    """
    entropy_column = mean_masked_genotype_entropy(gamma, delta).reshape((-1, 1))
    weighted = pi @ entropy_column
    return weighted.squeeze()

def match_genotypes(gammaA, gammaB):
    """Find, for every row of ``gammaA``, the closest row of ``gammaB``.

    Closeness is the city-block (L1) distance. Returns two pandas Series
    indexed by row of ``gammaA``: the index of the best-matching row in
    ``gammaB`` and that distance divided by the number of columns.
    """
    n_columns = gammaA.shape[1]
    distances = pd.DataFrame(cdist(gammaA, gammaB, metric='cityblock'))
    best_hit = distances.idxmin(axis=1)
    best_dist = distances.min(axis=1) / n_columns
    return best_hit, best_dist

def _rmse(x, y):
    return np.sqrt(np.square(x - y).mean())

def community_accuracy_test(pi_sim, pi_fit, reps=99):
    """Compare pairwise Bray-Curtis similarities of two community matrices.

    The error is the RMSE between the condensed similarity vectors
    (1 - Bray-Curtis) of ``pi_sim`` and ``pi_fit``. A null distribution is
    built from ``reps`` RMSEs between the simulated similarities and random
    permutations of themselves (drawn from numpy's global RNG).

    Returns ``(err, null, err / mean(null), fraction of null values < err)``.
    """
    bc_sim = 1 - pdist(pi_sim, metric='braycurtis')
    bc_fit = 1 - pdist(pi_fit, metric='braycurtis')
    err = _rmse(bc_sim, bc_fit)

    null = np.array([
        _rmse(bc_sim, np.random.permutation(bc_sim))
        for _ in range(reps)
    ])

    error_ratio = err / np.mean(null)
    frac_null_below = (np.sort(null) < err).mean()
    return err, null, error_ratio, frac_null_below

### workflow

In [None]:
%%writefile sfacts/workflow.py

import pyro
from sfacts.metagenotype_model import model, simulate, condition_model
from sfacts.estimation import (
    initialize_parameters_by_clustering_samples,
    estimate_parameters,
    merge_similar_genotypes
)
from sfacts.genotype import mask_missing_genotype
from sfacts.evaluation import match_genotypes, sample_mean_masked_genotype_entropy, community_accuracy_test
import time

def fit_to_data(
    y,
    m,
    fit_kwargs,
    preclust=True,
    preclust_kwargs=None,
    postclust=True,
    postclust_kwargs=None,
    seed=1,
):
    """Fit the model to count data, with optional pre- and post-clustering.

    Parameters
    ----------
    y, m : 2-D arrays of shape (n, g)
        Allele counts and total counts.
    fit_kwargs : dict
        Keyword arguments for ``estimate_parameters``. When ``preclust`` is
        false it must contain the strain count under the key ``'s'``. The
        dict is not modified.
    preclust : bool
        Initialise ``gamma``/``pi`` and choose the strain count with
        ``initialize_parameters_by_clustering_samples``.
    preclust_kwargs : dict
        Arguments for that initialiser; required when ``preclust`` is true.
    postclust : bool
        Merge similar fitted genotypes with ``merge_similar_genotypes``.
    postclust_kwargs : dict
        Arguments for the merge step; required when ``postclust`` is true.
    seed : int
        Seed passed to ``pyro.util.set_rng_seed``.

    Returns
    -------
    mrg : dict
        Fit with ``gamma``, ``pi`` and ``delta`` replaced by merged values
        (the same object as ``fit`` when ``postclust`` is false).
    fit : dict
        Raw estimates from ``estimate_parameters``.
    history : list
        Optimisation history from ``estimate_parameters``.
    """
    n, g = y.shape
    pyro.util.set_rng_seed(seed)
    # Shallow copy so that popping 's' below never alters the caller's dict
    # (which may be reused across repeated calls).
    fit_kwargs = dict(fit_kwargs)
    if preclust:
        assert preclust_kwargs is not None, "preclust=True requires preclust_kwargs"
        gamma_init, pi_init = initialize_parameters_by_clustering_samples(
            y,
            m,
            **preclust_kwargs
        )
        initialize_params = dict(gamma=gamma_init, pi=pi_init)
        s = gamma_init.shape[0]
    else:
        initialize_params = None
        s = fit_kwargs.pop('s')

    fit, history = estimate_parameters(
        model,
        data=dict(y=y, m=m),
        n=n,
        g=g,
        s=s,
        initialize_params=initialize_params,
        **fit_kwargs,
    )

    if postclust:
        # Same explicit check as for preclust_kwargs, instead of an opaque
        # TypeError from unpacking None.
        assert postclust_kwargs is not None, "postclust=True requires postclust_kwargs"
        merge_gamma, merge_pi, merge_delta = merge_similar_genotypes(
            fit['gamma'],
            fit['pi'],
            delta=fit['delta'],
            **postclust_kwargs,
        )
        mrg = fit.copy()
        mrg['gamma'] = merge_gamma
        mrg['pi'] = merge_pi
        mrg['delta'] = merge_delta
    else:
        mrg = fit

    return mrg, fit, history

def simulate_fit_and_evaluate(
    s_sim,
    n_sim,
    g_sim,
    n_fit,
    g_fit,
    sim_kwargs,
    fit_kwargs,
    seed=1,
    preclust=True,
    preclust_kwargs=None,
    postclust=True,
    postclust_kwargs=None,
    return_raw=False,
):
    """Simulate data, fit the first ``n_fit`` x ``g_fit`` block and score the fit.

    The data are simulated from ``model`` with sizes ``n_sim``, ``g_sim``,
    ``s_sim`` and the options in ``sim_kwargs``, then fitted with
    ``fit_to_data``. ``return_raw`` is accepted but not referenced in the body.

    Returns a 7-tuple:
    ``(weighted_mean_genotype_error, beta_diversity_error_ratio,
    strain_count_error, mean_sample_weighted_genotype_entropy, runtime, sim,
    mrg)``, where ``runtime`` is the wall-clock time of the ``fit_to_data``
    call in seconds.
    """
    pyro.util.set_rng_seed(seed)
    simulation_model = condition_model(
        model,
        n=n_sim,
        g=g_sim,
        s=s_sim,
        **sim_kwargs,
    )
    sim = simulate(simulation_model)

    # Time only the fitting step.
    started = time.time()
    mrg, _fit, _history = fit_to_data(
        sim['y'][:n_fit, :g_fit],
        sim['m'][:n_fit, :g_fit],
        fit_kwargs=fit_kwargs,
        preclust=preclust,
        preclust_kwargs=preclust_kwargs,
        postclust=postclust,
        postclust_kwargs=postclust_kwargs,
        seed=seed,
    )
    finished = time.time()
    runtime = finished - started

    true_pi = sim['pi'][:n_fit]

    # Genotype error: distance from each simulated strain to its closest
    # estimated strain, weighted by the strain's mean simulated abundance.
    sim_gamma_adj = mask_missing_genotype(sim['gamma'][:, :g_fit], sim['delta'][:, :g_fit])
    mrg_gamma_adj = mask_missing_genotype(mrg['gamma'], mrg['delta'])
    _best_hit, best_dist = match_genotypes(sim_gamma_adj, mrg_gamma_adj)
    weighted_mean_genotype_error = (best_dist * true_pi.mean(0)).sum()

    # Third element of the result is the error relative to the null mean.
    beta_diversity_error_ratio = community_accuracy_test(true_pi, mrg['pi'])[2]

    strain_count_error = mrg['gamma'].shape[0] - s_sim

    per_sample_entropy = sample_mean_masked_genotype_entropy(
        mrg['pi'], mrg['gamma'], mrg['delta']
    )
    mean_sample_weighted_genotype_entropy = per_sample_entropy.mean()

    return (
        weighted_mean_genotype_error,
        beta_diversity_error_ratio,
        strain_count_error,
        mean_sample_weighted_genotype_entropy,
        runtime,
        sim,
        mrg
    )

## Imports

In [None]:
from sfacts import *

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

In [None]:
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

## Model Specification

In [None]:
epsilon_hyper_alpha, epsilon_hyper_beta = 1.5, 1.5 / 0.01
plt.hist(pyro.sample('epsilon_hyper', dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta).expand([10000])).cpu().numpy(), bins=100)
None

In [None]:
shape_info(model, n=100, g=200, s=20)

## Simulation

### SimShape-1: Small study

In [None]:
seed = 1
pyro.util.set_rng_seed(seed)

n_sim = 100
g_sim = 5000
s_sim = 20

sim1 = simulate(
    condition_model(
        model,
        data=dict(
            alpha_hyper_mean=100.
        ),
        n=n_sim,
        g=g_sim,
        s=s_sim,
        gamma_hyper=0.01,
        delta_hyper_temp=0.01,
        delta_hyper_p=0.7,
        pi_hyper=0.5,
        rho_hyper=2.,
        mu_hyper_mean=2.,
        mu_hyper_scale=0.5,
        m_hyper_r=10.,
        alpha_hyper_scale=0.5,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5/0.01,
        device='cpu'
    )
)

## Visualization

In [None]:
n_plt = 100
g_plt = 200
s_plt = 20

In [None]:
plot_genotype(counts_to_p_estimate(sim1['y'][:n_plt, :g_plt], sim1['m'][:n_plt, :g_plt]), linkage_kw=dict(progress=True))

In [None]:
# `pi` is indexed (sample, strain), so the sample limit comes first.
plot_community(sim1['pi'][:n_plt, :s_plt])

In [None]:
plot_genotype_similarity(counts_to_p_estimate(sim1['y'][:n_plt, :g_plt], sim1['m'][:n_plt, :g_plt]), linkage_kw=dict(progress=True))

In [None]:
plot_genotype(sim1['gamma'][:s_plt, :g_plt])

In [None]:
plot_missing(sim1['delta'][:s_plt, :g_plt])

In [None]:
plot_missing(sim1['nu'][:n_plt, :g_plt])

In [None]:
sns.clustermap(sim1['m'][:n_plt, :g_plt], norm=mpl.colors.SymLogNorm(linthresh=1))

In [None]:
plot_genotype_similarity(sim1['gamma'][:s_plt, :g_plt])

In [None]:
plt.hist(sim1['epsilon'], bins=50)
None

In [None]:
plt.hist(sim1['alpha'], bins=20)
None

## Estimation

### Initialization

In [None]:
g_fit = 1000  # sim1['y'].shape[1]
n_fit = sim1['y'].shape[0]

sim1_gamma_init, sim1_pi_init = initialize_parameters_by_clustering_samples(
    sim1['y'][:n_fit, :g_fit],
    sim1['m'][:n_fit, :g_fit],
    thresh=0.1,
    additional_strains_factor=0.,
    progress=True,
)

print(sim1_pi_init.shape)

In [None]:
plot_genotype(sim1_gamma_init[:s_plt, :g_plt])

In [None]:
plot_genotype_similarity(sim1_gamma_init)

In [None]:
plot_community(sim1_pi_init[:n_plt, :s_plt])

### Fitting

In [None]:
s_fit = sim1_gamma_init.shape[0]
initialize_params = dict(gamma=sim1_gamma_init, pi=sim1_pi_init)

sim1_fit1, history = estimate_parameters(
    model,
    data=dict(y=sim1['y'][:, :g_fit], m=sim1['m'][:, :g_fit]),
    n=n_fit,
    g=g_fit,
    s=s_fit,
    gamma_hyper=0.1,
    pi_hyper=1.0,
    rho_hyper=0.5,
    mu_hyper_mean=5,
    mu_hyper_scale=5.,
    m_hyper_r=10.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=10.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    initialize_params=initialize_params,
    device='cpu',
    lag=100,
    lr=1e-1,
)

### Merging Strains

In [None]:
sim1_fit1_gamma_merge, sim1_fit1_pi_merge, sim1_fit1_delta_merge  = merge_similar_genotypes(
    sim1_fit1['gamma'],
    sim1_fit1['pi'],
    delta=sim1_fit1['delta'],
    thresh=0.1,
)

# print(sim1_gamma_init.shape[0], sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])
print(sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])

## Evaluation

In [None]:
sim1_gamma_adjusted = mask_missing_genotype(sim1['gamma'][:, :g_fit], sim1['delta'][:, :g_fit])
sim1_fit1_gamma_adjusted = mask_missing_genotype(sim1_fit1['gamma'], sim1_fit1['delta'])

### Ground Truth

#### Visualization

In [None]:
plot_genotype_comparison(
    data=dict(
        true=sim1_gamma_adjusted[:, :g_plt],
#         fit=sim1_fit1['gamma'][:, :g_plt],
        adj=sim1_fit1_gamma_adjusted[:, :g_plt],
#         init=sim1_gamma_init,
#         merg=sim1_fit1_gamma_merge,
    ),
    linkage_kw=dict(progress=True),
)

In [None]:
plot_community_comparison(
    data=dict(
        true=sim1['pi'],
        fit=sim1_fit1['pi'],
#         init=sim1_pi_init,
#         merg=sim1_fit1_pi_merge,
    ),
)

In [None]:
plt.scatter(sim1['epsilon'], sim1_fit1['epsilon'])
plt.plot([0, 0.04], [0, 0.04])

In [None]:
plt.scatter(sim1['alpha'], sim1_fit1['alpha'])
plt.plot([0, 200], [0, 200])

In [None]:
sns.heatmap(sim1_fit1['delta'], vmin=0, vmax=1)

In [None]:
plt.scatter(sim1['mu'], sim1_fit1['mu'])
plt.plot([0, 40], [0, 40])

In [None]:
# TODO: Plot comparing genotype accuracy to true strain abundance
# colored by mean entropy of the estimated genotype masked by delta

#### Fit scores

In [None]:
plt.scatter(sim1_fit1['alpha'], sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']))
sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']).mean()

In [None]:
best_hit, best_dist = match_genotypes(sim1_gamma_adjusted[:, :g_fit], sim1_fit1_gamma_adjusted[:, :g_fit])

print('weighted_mean_distance:', (best_dist * sim1['pi'].mean(0)).sum())
plt.scatter((sim1['pi'] * sim1['mu'].reshape(-1, 1)).sum(0), best_dist)

In [None]:
bc_sim = 1 - pdist(sim1['pi'], metric='braycurtis')
bc_fit = 1 - pdist(sim1_fit1['pi'], metric='braycurtis')
plt.scatter(
    bc_sim,
    bc_fit,
    marker='.',
    alpha=0.2,
)

community_accuracy_test(sim1['pi'], sim1_fit1['pi'])

### No Ground Truth

#### Visualization

In [None]:
# Strains that are not representative of true haplotypes
# are high entropy (even after masking with delta)
# and have low estimated total coverage.

best_true_strain, best_true_strain_dist = match_genotypes(sim1_fit1_gamma_adjusted[:, :g_fit], sim1_gamma_adjusted[:, :g_fit])
best_true_strain_dist

plt.scatter((sim1_fit1['pi'] * sim1_fit1['mu'].reshape((-1, 1))).sum(0), best_true_strain_dist, c=mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']))

In [None]:
plot_genotype(sim1_fit1_gamma_adjusted[:, :g_plt], linkage_kw=dict(progress=True))

In [None]:
plot_community(sim1_fit1['pi'])

#### Confidence Scores

In [None]:
plt.hist(mean_masked_genotype_entropy(sim1_fit1['gamma'][:, :g_plt], sim1_fit1['delta'][:, :g_plt]))
None

In [None]:
plt.hist(sim1_fit1['alpha'], bins=20)
None

In [None]:
plot_genotype(sim1_fit1['gamma'][mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']) < 0.1, :g_fit])

## Experiments

### Experiment 1: Average and variation in accuracy

In [None]:
results = []
for seed in range(10):
    generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
        s_sim=20,
        n_sim=100,
        g_sim=500,
        n_fit=100,
        g_fit=500,
        sim_kwargs=dict(
            data=dict(
                alpha_hyper_mean=100.
            ),
            gamma_hyper=0.01,
            delta_hyper_temp=0.01,
            delta_hyper_p=0.7,
            pi_hyper=0.5,
            rho_hyper=2.,
            mu_hyper_mean=1.,
            mu_hyper_scale=0.5,
            m_hyper_r=10.,
            alpha_hyper_scale=0.5,
            epsilon_hyper_alpha=1.5,
            epsilon_hyper_beta=1.5/0.01,
            device='cpu'
        ),
        preclust_kwargs=dict(
            thresh=0.1,
            additional_strains_factor=0.1,
            progress=False,
        ),
        fit_kwargs=dict(
            gamma_hyper=0.01,
            pi_hyper=1.0,
            rho_hyper=0.5,
            mu_hyper_mean=5,
            mu_hyper_scale=5.,
            m_hyper_r=10.,
            delta_hyper_temp=0.1,
            delta_hyper_p=0.9,
            alpha_hyper_hyper_mean=100.,
            alpha_hyper_hyper_scale=10.,
            alpha_hyper_scale=0.5,
            epsilon_hyper_alpha=1.5,
            epsilon_hyper_beta=1.5 / 0.01,
            device='cpu',
            lag=10,
            lr=1e-0,
            progress=False
        ),
        postclust_kwargs=dict(
            thresh=0.1,
        ),
        seed=seed,
        return_raw=True,
    )
    results.append((seed, generr, comperr, scounter, entropy, runtime))
    print(seed, generr, comperr, scounter, entropy, runtime, sep='\t')
         
results1 = pd.DataFrame(results, columns=['seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    ax.hist(results1[stat])
    ax.set_title(stat)
fig.tight_layout()

### Experiment 2: Benefits of increasing sample data (preclust)

In [None]:
results = []
for n_fit in [20, 50, 100, 150, 200, 500]:
    replicates = 0
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=500,
            g_sim=500,
            n_fit=n_fit,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((n_fit, seed, generr, comperr, scounter, entropy, runtime))
        print(n_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')
         
results2 = pd.DataFrame(results, columns=['n_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    results2.set_index(['n_fit', 'seed'])[stat].unstack().plot(ax=ax)
    ax.set_title(stat)
    ax.legend_.set_visible(False)
    if stat == 'generr':
        ax.set_yscale('log')
fig.tight_layout()

### Experiment 3: Benefits of increasing sample data (no preclust, `s` known)

In [None]:
# Sweep the number of samples used for fitting (n_fit). Preclustering is
# disabled and the fit is given s=20, the same value as s_sim.
results = []
for n_fit in [20, 50, 100, 150, 200, 500]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=500,
            g_sim=500,
            n_fit=n_fit,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust=False,
            fit_kwargs=dict(
                s=20,
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((n_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(n_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (n_fit, seed) replicate.
results3 = pd.DataFrame(results, columns=['n_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 3 summary: each evaluation statistic against n_fit, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results3.set_index(['n_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    if stat == 'generr':
        ax.set_yscale('log')

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 4: Benefits of increasing depth

In [None]:
# Sweep the simulated mu_hyper_mean (the "depth" of the section heading);
# the fit-side hyperparameters are held fixed.
results = []
for mu_hyper_mean_sim in [0.5, 1., 2., 5., 10., 100.]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=mu_hyper_mean_sim,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5.,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((mu_hyper_mean_sim, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(mu_hyper_mean_sim, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (mu_hyper_mean_sim, seed) replicate.
results4 = pd.DataFrame(results, columns=['mu_hyper_mean_sim', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 4 summary: each evaluation statistic against mu_hyper_mean_sim, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results4.set_index(['mu_hyper_mean_sim', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    # The swept values span 0.5 to 100, so use a log x-axis.
    ax.set_xscale('log')
    if stat == 'generr':
        ax.set_yscale('log')

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 5: Benefits of increasing genotype data

In [None]:
# Sweep g_fit, the amount of genotype data passed to the fit, with g_sim fixed at 2000.
results = []
for g_fit in [100, 250, 500, 1000, 2000]:
    # Five seeded replicates per setting (note: a different seed list from the other experiments).
    for seed in [2, 3, 4, 5, 7]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=2000,
            n_fit=100,
            g_fit=g_fit,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((g_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(g_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (g_fit, seed) replicate.
results5 = pd.DataFrame(results, columns=['g_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 5 summary: each evaluation statistic against g_fit, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results5.set_index(['g_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 6: Strain-number estimation

In [None]:
# Sweep s_fit, the strain count handed to the fit, around the simulated
# s_sim=20. Preclustering is disabled so that `s` is taken from fit_kwargs.
results = []
for s_fit in [5, 10, 15, 20, 25, 30, 50]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust=False,
            fit_kwargs=dict(
                s=s_fit,
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((s_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(s_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (s_fit, seed) replicate.
results6 = pd.DataFrame(results, columns=['s_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 6 summary: each evaluation statistic against s_fit, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results6.set_index(['s_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 7: Effects of genotype fuzziness

In [None]:
# Sweep the fit-side gamma_hyper; the simulation keeps gamma_hyper=0.01 throughout.
results = []
for gamma_hyper_fit in [1e-8, 1e-5, 1e-3, 1e-2, 5e-2, 1e-1, 5e-1]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=gamma_hyper_fit,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((gamma_hyper_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(gamma_hyper_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (gamma_hyper_fit, seed) replicate.
results7 = pd.DataFrame(results, columns=['gamma_hyper_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 7 summary: each evaluation statistic against gamma_hyper_fit, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results7.set_index(['gamma_hyper_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    # The swept values span 1e-8 to 0.5, so use a log x-axis.
    ax.set_xscale('log')

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 8: Effects of diversity regularization

In [None]:
# Sweep the fit-side rho_hyper (the "diversity regularization" of the section
# heading). Preclustering is disabled and the fit is given s=30 while s_sim=20.
results = []
for rho_hyper_fit in [0.01, 0.05, 0.1, 0.25, 0.5, 1.0]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust=False,
            fit_kwargs=dict(
                s=30,
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=rho_hyper_fit,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((rho_hyper_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(rho_hyper_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (rho_hyper_fit, seed) replicate.
results8 = pd.DataFrame(results, columns=['rho_hyper_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 8 summary: each evaluation statistic against rho_hyper_fit, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results8.set_index(['rho_hyper_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    # The swept values span 0.01 to 1.0, so use a log x-axis.
    ax.set_xscale('log')

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 9: Effects of heterogeneity regularization

In [None]:
# Sweep the fit-side pi_hyper (the "heterogeneity regularization" of the
# section heading); the simulation keeps pi_hyper=0.5 throughout.
results = []
for pi_hyper_fit in [0.1, 0.25, 0.5, 1.0, 1.5]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=pi_hyper_fit,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((pi_hyper_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(pi_hyper_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (pi_hyper_fit, seed) replicate.
results9 = pd.DataFrame(results, columns=['pi_hyper_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 9 summary: each evaluation statistic against pi_hyper_fit, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results9.set_index(['pi_hyper_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    ax.set_xscale('log')

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 10: Effects of preclustering threshold

In [None]:
# Sweep the preclustering threshold (preclust_kwargs['thresh']); the
# postclustering threshold stays at 0.1.
results = []
for preclust_thresh in [0.03, 0.05, 0.08, 0.1, 0.12, 0.15, 0.2]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=preclust_thresh,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((preclust_thresh, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(preclust_thresh, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (preclust_thresh, seed) replicate.
results10 = pd.DataFrame(results, columns=['preclust_thresh', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 10 summary: each evaluation statistic against preclust_thresh, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results10.set_index(['preclust_thresh', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    if stat == 'comperr':
        ax.set_yscale('log')

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Experiment 11: Effects of strain merging (postclustering) threshold

In [None]:
# Sweep the postclustering (strain-merging) threshold, postclust_kwargs['thresh'];
# the preclustering threshold stays at 0.1.
results = []
for postclust_thresh in [0.03, 0.05, 0.08, 0.1, 0.12, 0.15, 0.2]:
    # Five seeded replicates per setting.
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are unpacked with the statistics but not used in this cell.
        generr, comperr, scounter, entropy, runtime, sim, fit = simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=1e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=postclust_thresh,
            ),
            seed=seed,
            return_raw=True,
        )
        results.append((postclust_thresh, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line so partial results are visible while the sweep runs.
        print(postclust_thresh, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (postclust_thresh, seed) replicate.
results11 = pd.DataFrame(results, columns=['postclust_thresh', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Experiment 11 summary: each evaluation statistic against postclust_thresh, one line per seed.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # Unstacking the inner index level ('seed') gives one column, hence one line, per seed.
    # legend=False stops pandas from drawing a legend at all.
    results11.set_index(['postclust_thresh', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    if stat == 'comperr':
        ax.set_yscale('log')

# Five statistics on a 3x2 grid leave the last panel unused; switch it off
# instead of showing an empty frame.
axs[-1, -1].set_axis_off()
fig.tight_layout()

### Visualize all

In [None]:
# TODO: Big matrix plot.

# Summary grid: one row per evaluation statistic, one column per experiment
# (2 through 11). Columns share an x-axis and rows share a y-axis.
fig, axs = plt.subplots(5, 10, figsize=(2*11, 2*5), sharex='col', sharey='row')

for (stat, scale_y), row in zip([
    ('generr', 'log'),
    ('comperr', 'log'),
    ('scounterr', 'symlog'),
    ('entropy', 'linear'),
    ('runtime', 'log')
], axs):
    # Each entry: (results table, swept column, x-axis scale, experiment number used as title).
    # The loop variable is named `results_df` so that it does not overwrite the
    # notebook-level `results` record list built by the experiment cells.
    for (results_df, indexer, scale_x, title), ax in zip([
        (results2, 'n_fit', 'log', 2),
        (results3, 'n_fit', 'log', 3),
        (results4, 'mu_hyper_mean_sim', 'log', 4),
        (results5, 'g_fit', 'log', 5),
        (results6, 's_fit', 'linear', 6),
        (results7, 'gamma_hyper_fit', 'log', 7),
        (results8, 'rho_hyper_fit', 'log', 8),
        (results9, 'pi_hyper_fit', 'log', 9),
        (results10, 'preclust_thresh', 'linear', 10),
        (results11, 'postclust_thresh', 'linear', 11),
    ], row):
        # One line per seed; legend=False stops pandas from drawing a legend at all.
        results_df.set_index([indexer, 'seed'])[stat].unstack().plot(ax=ax, legend=False)
        ax.set_ylabel(stat)
        ax.set_xlabel(indexer)
        ax.set_xscale(scale_x)
        ax.set_yscale(scale_y)
        ax.set_title(title)

fig.tight_layout()