## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

# Render every figure in this notebook at 200 dots per inch.
mpl.rc('figure', dpi=200)

In [None]:
# Silence the TracerWarning whose message starts with the text below.
# The pattern is one string literal split across lines for readability.
# TODO: optionally narrow the filter with `module=` / `lineno=`; the right
# module regex (something like "trace_elbo") has not been worked out yet.
warnings.filterwarnings(
    action="ignore",
    message=(
        "torch.tensor results are registered as constants in the trace. "
        "You can safely ignore this warning if you use this function to create "
        "tensors out of constant variables that would be the same every time "
        "you call this function. "
        "In any other case, this might cause the trace to be incorrect."
    ),
    category=TracerWarning,
)

## Experiments

### Experiment 0: Average and variation in fitting accuracy

In [None]:
# Experiment 0: run the simulate -> fit -> evaluate workflow for every
# (fitting seed, simulation seed) pair and tabulate the returned statistics.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for seed_fit in range(10):
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust_kwargs={
                    'thresh': 0.1,
                    'additional_strains_factor': 0.1,
                    'progress': True,
                },
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 2.0,
                    'progress': True,
                },
                postclust_kwargs={'thresh': 0.1},
                # The simulation seed and the fitting seed vary independently.
                seed_sim=seed,
                seed_fit=seed_fit,
                quiet=True,
            )
        )
        results.append(
            (seed_fit, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(seed_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results0 = pd.DataFrame(
    results,
    columns=[
        'seed_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# One panel per summary statistic: x-axis is `seed_fit`, one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results0.set_index(['seed_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 1: Average and variation in accuracy

In [None]:
# Experiment 1: ten independent runs of the simulate -> fit -> evaluate
# workflow; each run uses the same value for the simulation and fitting seed.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for seed in range(10):
    generr, comperr, scounter, entropy, runtime, sim, fit = (
        sf.workflow.simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            # Settings handed over as `sim_kwargs`.
            sim_kwargs={
                'data': {'alpha_hyper_mean': 100.0},
                'gamma_hyper': 0.01,
                'delta_hyper_temp': 0.01,
                'delta_hyper_p': 0.7,
                'pi_hyper': 0.5,
                'rho_hyper': 2.0,
                'mu_hyper_mean': 1.0,
                'mu_hyper_scale': 0.5,
                'm_hyper_r': 10.0,
                'alpha_hyper_scale': 0.5,
                'epsilon_hyper_alpha': 1.5,
                'epsilon_hyper_beta': 1.5 / 0.01,
                'device': 'cpu',
            },
            preclust_kwargs={
                'thresh': 0.1,
                'additional_strains_factor': 0.1,
                'progress': False,
            },
            # Settings handed over as `fit_kwargs`.
            fit_kwargs={
                'gamma_hyper': 0.01,
                'pi_hyper': 1.0,
                'rho_hyper': 0.5,
                'mu_hyper_mean': 5,
                'mu_hyper_scale': 5.0,
                'm_hyper_r': 10.0,
                'delta_hyper_temp': 0.1,
                'delta_hyper_p': 0.9,
                'alpha_hyper_hyper_mean': 100.0,
                'alpha_hyper_hyper_scale': 10.0,
                'alpha_hyper_scale': 0.5,
                'epsilon_hyper_alpha': 1.5,
                'epsilon_hyper_beta': 1.5 / 0.01,
                'device': 'cpu',
                'lag': 10,
                'lr': 1.0,
                'progress': False,
            },
            postclust_kwargs={'thresh': 0.1},
            seed_sim=seed,
            seed_fit=seed,
            quiet=True,
        )
    )
    results.append((seed, generr, comperr, scounter, entropy, runtime))
    # Tab-separated progress line, one per completed run.
    print(seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results1 = pd.DataFrame(
    results,
    columns=['seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'],
)

In [None]:
# Histogram of each summary statistic across the Experiment 1 runs.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    ax.hist(results1[stat])
    ax.set_title(stat)
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 2: Benefits of increasing sample data (preclust)

In [None]:
# Experiment 2: sweep `n_fit` (with preclustering enabled), running five seeds
# per setting; each run uses the same value for the simulation and fitting seed.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for n_fit in [20, 50, 100, 150, 200, 500]:
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=500,
                g_sim=500,
                n_fit=n_fit,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust_kwargs={
                    'thresh': 0.1,
                    'additional_strains_factor': 0.1,
                    'progress': False,
                },
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append((n_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line, one per completed run.
        print(n_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results2 = pd.DataFrame(
    results,
    columns=['n_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'],
)

In [None]:
# One panel per summary statistic: x-axis is `n_fit`, one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results2.set_index(['n_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    if stat == 'generr':
        # Only the generr panel gets a log-scaled y-axis.
        ax.set_yscale('log')
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 3: Benefits of increasing sample data (no preclust, `s` known)

In [None]:
# Experiment 3: sweep `n_fit` with preclustering switched off and `s=20`
# passed to the fit (matching `s_sim`), running five seeds per setting.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for n_fit in [20, 50, 100, 150, 200, 500]:
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=500,
                g_sim=500,
                n_fit=n_fit,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust=False,
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    's': 20,
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append((n_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line, one per completed run.
        print(n_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results3 = pd.DataFrame(
    results,
    columns=['n_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'],
)

In [None]:
# One panel per summary statistic: x-axis is `n_fit`, one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results3.set_index(['n_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    if stat == 'generr':
        # Only the generr panel gets a log-scaled y-axis.
        ax.set_yscale('log')
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 4a: Benefits of increasing depth (no preclust)

In [None]:
# Experiment 4a: for each seed, sweep the simulation-side `mu_hyper_mean`
# from the largest value down to the smallest, with preclustering switched
# off and `s=20` passed to the fit.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for seed in [0, 1, 3, 4, 5]:
    for mu_hyper_mean_sim in [1000., 50., 10., 5., 2., 1., 0.5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`; only `mu_hyper_mean` varies.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': mu_hyper_mean_sim,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                # No `preclust_kwargs` are passed because preclustering is off.
                preclust=False,
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    's': 20,
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5.0,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 0.1,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append(
            (mu_hyper_mean_sim, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(mu_hyper_mean_sim, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results4a = pd.DataFrame(
    results,
    columns=[
        'mu_hyper_mean_sim', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# One panel per summary statistic: x-axis is `mu_hyper_mean_sim` (log scale),
# one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results4a.set_index(['mu_hyper_mean_sim', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    ax.set_xscale('log')
    if stat == 'generr':
        # The generr panel uses a logit y-axis restricted to [0.01, 0.5].
        ax.set_yscale('logit')
        ax.set_ylim(1e-2, 5e-1)
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 4b: Effects of increasing depth (with preclust)

In [None]:
# Experiment 4b: for each seed, sweep the simulation-side `mu_hyper_mean`
# from the largest value down to the smallest, with preclustering enabled.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for seed in [0, 1, 3, 4, 5]:
    for mu_hyper_mean_sim in [1000., 50., 10., 5., 2., 1., 0.5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`; only `mu_hyper_mean` varies.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': mu_hyper_mean_sim,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust_kwargs={
                    'thresh': 0.1,
                    'additional_strains_factor': 0.1,
                    'progress': False,
                },
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5.0,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 0.1,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append(
            (mu_hyper_mean_sim, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(mu_hyper_mean_sim, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results4b = pd.DataFrame(
    results,
    columns=[
        'mu_hyper_mean_sim', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# One panel per summary statistic: x-axis is `mu_hyper_mean_sim` (log scale),
# one line per `seed`. Every panel, including generr, keeps the default
# linear y-axis.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results4b.set_index(['mu_hyper_mean_sim', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    ax.set_xscale('log')
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 5: Benefits of increasing genotype data

In [None]:
# Experiment 5: sweep `g_fit` (with `g_sim` fixed at 2000), running five seeds
# per setting; each run uses the same value for the simulation and fitting seed.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for g_fit in [100, 250, 500, 1000, 2000]:
    # NOTE(review): this seed list differs from the [0, 1, 3, 4, 5] used by
    # most other experiments in this notebook; confirm that is intended.
    for seed in [0, 3, 4, 5, 8]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=2000,
                n_fit=100,
                g_fit=g_fit,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust_kwargs={
                    'thresh': 0.1,
                    'additional_strains_factor': 0.1,
                    'progress': False,
                },
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append((g_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line, one per completed run.
        print(g_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results5 = pd.DataFrame(
    results,
    columns=['g_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'],
)

In [None]:
# One panel per summary statistic: x-axis is `g_fit`, one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results5.set_index(['g_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 6: Strain-number estimation

In [None]:
# Experiment 6: sweep the `s` value passed to the fit (`s_fit`) around the
# simulated `s_sim=20`, with preclustering switched off; five seeds per setting.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for s_fit in [5, 10, 15, 20, 25, 30, 50]:
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust=False,
                # Settings handed over as `fit_kwargs`; only `s` varies.
                fit_kwargs={
                    's': s_fit,
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append((s_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line, one per completed run.
        print(s_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results6 = pd.DataFrame(
    results,
    columns=['s_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'],
)

In [None]:
# One panel per summary statistic: x-axis is `s_fit`, one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results6.set_index(['s_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 7: Effects of genotype fuzziness

In [None]:
# Experiment 7: sweep the fit-side `gamma_hyper` (`gamma_hyper_fit`) while the
# simulation keeps `gamma_hyper=0.01`; five seeds per setting.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for gamma_hyper_fit in [1e-8, 1e-5, 1e-3, 1e-2, 5e-2, 1e-1, 5e-1]:
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust_kwargs={
                    'thresh': 0.1,
                    'additional_strains_factor': 0.1,
                    'progress': False,
                },
                # Settings handed over as `fit_kwargs`; only `gamma_hyper` varies.
                fit_kwargs={
                    'gamma_hyper': gamma_hyper_fit,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append(
            (gamma_hyper_fit, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(gamma_hyper_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results7 = pd.DataFrame(
    results,
    columns=[
        'gamma_hyper_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# One panel per summary statistic: x-axis is `gamma_hyper_fit` (log scale),
# one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results7.set_index(['gamma_hyper_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    ax.set_xscale('log')
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 8: Effects of diversity regularization

In [None]:
# Experiment 8: sweep the fit-side `rho_hyper` (`rho_hyper_fit`) with
# preclustering switched off and `s=30` passed to the fit; five seeds per setting.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for rho_hyper_fit in [1e-10, 0.0001, 0.01, 0.05, 0.1, 0.25, 0.5, 1.0]:
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust=False,
                # Settings handed over as `fit_kwargs`; only `rho_hyper` varies.
                fit_kwargs={
                    's': 30,
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': rho_hyper_fit,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append(
            (rho_hyper_fit, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(rho_hyper_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results8 = pd.DataFrame(
    results,
    columns=[
        'rho_hyper_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# One panel per summary statistic: x-axis is `rho_hyper_fit` (log scale),
# one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results8.set_index(['rho_hyper_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    ax.set_xscale('log')
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 9: Effects of heterogeneity regularization

In [None]:
# Experiment 9: for each seed, sweep the fit-side `pi_hyper` (`pi_hyper_fit`)
# with preclustering enabled. The simulation uses `mu_hyper_mean=10.0` here.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for seed in [0, 1, 3, 4, 5]:
    for pi_hyper_fit in [1e-4, 1e-3, 1e-2, 5e-1, 1e0, 1e1, 1e2]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 10.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust_kwargs={
                    'thresh': 0.1,
                    'additional_strains_factor': 0.1,
                    'progress': False,
                },
                # Settings handed over as `fit_kwargs`; only `pi_hyper` varies.
                fit_kwargs={
                    'gamma_hyper': 0.01,
                    'pi_hyper': pi_hyper_fit,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append(
            (pi_hyper_fit, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(pi_hyper_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results9 = pd.DataFrame(
    results,
    columns=[
        'pi_hyper_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# One panel per summary statistic: x-axis is `pi_hyper_fit` (log scale),
# one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results9.set_index(['pi_hyper_fit', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    ax.set_xscale('log')
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 10: Effects of preclustering threshold

In [None]:
# Experiment 10: sweep the preclustering `thresh` (`preclust_thresh`) while the
# postclustering threshold stays at 0.1; five seeds per setting.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for preclust_thresh in [0.03, 0.05, 0.08, 0.1, 0.12, 0.15, 0.2]:
    # NOTE(review): this seed list differs from the [0, 1, 3, 4, 5] used by
    # most other experiments in this notebook; confirm that is intended.
    for seed in [0, 1, 3, 4, 6]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                # Only the preclustering threshold varies.
                preclust_kwargs={
                    'thresh': preclust_thresh,
                    'additional_strains_factor': 0.1,
                    'progress': False,
                },
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                postclust_kwargs={'thresh': 0.1},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append(
            (preclust_thresh, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(preclust_thresh, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results10 = pd.DataFrame(
    results,
    columns=[
        'preclust_thresh', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# One panel per summary statistic: x-axis is `preclust_thresh`, one line per `seed`.
fig, axs = plt.subplots(3, 2)

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], axs.flatten()):
    # legend=False avoids creating a legend instead of hiding it afterwards.
    results10.set_index(['preclust_thresh', 'seed'])[stat].unstack().plot(ax=ax, legend=False)
    ax.set_title(stat)
    if stat == 'comperr':
        # Only the comperr panel gets a log-scaled y-axis.
        ax.set_yscale('log')
# Five statistics on a 3x2 grid leave the last panel unused; hide its empty frame.
axs[-1, -1].set_visible(False)
fig.tight_layout()

### Experiment 11: Effects of strain merging (postclustering) threshold

In [None]:
# Experiment 11: sweep the postclustering `thresh` (`postclust_thresh`) while
# the preclustering threshold stays at 0.1; five seeds per setting.
# `sim` and `fit` hold the objects from the most recent run.
results = []
for postclust_thresh in [0.03, 0.05, 0.08, 0.1, 0.12, 0.15, 0.2]:
    for seed in [0, 1, 3, 4, 5]:
        generr, comperr, scounter, entropy, runtime, sim, fit = (
            sf.workflow.simulate_fit_and_evaluate(
                s_sim=20,
                n_sim=100,
                g_sim=500,
                n_fit=100,
                g_fit=500,
                # Settings handed over as `sim_kwargs`.
                sim_kwargs={
                    'data': {'alpha_hyper_mean': 100.0},
                    'gamma_hyper': 0.01,
                    'delta_hyper_temp': 0.01,
                    'delta_hyper_p': 0.7,
                    'pi_hyper': 0.5,
                    'rho_hyper': 2.0,
                    'mu_hyper_mean': 1.0,
                    'mu_hyper_scale': 0.5,
                    'm_hyper_r': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                },
                preclust_kwargs={
                    'thresh': 0.1,
                    'additional_strains_factor': 0.1,
                    'progress': False,
                },
                # Settings handed over as `fit_kwargs`.
                fit_kwargs={
                    'gamma_hyper': 0.01,
                    'pi_hyper': 1.0,
                    'rho_hyper': 0.5,
                    'mu_hyper_mean': 5,
                    'mu_hyper_scale': 5.0,
                    'm_hyper_r': 10.0,
                    'delta_hyper_temp': 0.1,
                    'delta_hyper_p': 0.9,
                    'alpha_hyper_hyper_mean': 100.0,
                    'alpha_hyper_hyper_scale': 10.0,
                    'alpha_hyper_scale': 0.5,
                    'epsilon_hyper_alpha': 1.5,
                    'epsilon_hyper_beta': 1.5 / 0.01,
                    'device': 'cpu',
                    'lag': 10,
                    'lr': 1.0,
                    'progress': False,
                },
                # Only the postclustering threshold varies.
                postclust_kwargs={'thresh': postclust_thresh},
                seed_sim=seed,
                seed_fit=seed,
                quiet=True,
            )
        )
        results.append(
            (postclust_thresh, seed, generr, comperr, scounter, entropy, runtime)
        )
        # Tab-separated progress line, one per completed run.
        print(postclust_thresh, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

results11 = pd.DataFrame(
    results,
    columns=[
        'postclust_thresh', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime',
    ],
)

In [None]:
# Draw each statistic from Experiment 11 against `postclust_thresh`,
# one line per seed, on a 3x2 grid of panels.
fig, axs = plt.subplots(3, 2)
panels = axs.flatten()
indexed11 = results11.set_index(['postclust_thresh', 'seed'])

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], panels):
    indexed11[stat].unstack().plot(ax=ax)
    ax.set_title(stat)
    ax.get_legend().set_visible(False)
    # Composition error is the only statistic shown on a log y-axis.
    if stat == 'comperr':
        ax.set_yscale('log')
fig.tight_layout()

### Experiment 12: Learning rate

In [None]:
# Experiment 12: sweep the learning rate passed to the fit (`lr`).
# Every learning rate is run once per seed; the same seed drives both the
# simulation and the fit, and all other settings are held fixed.
# NOTE(review): seed 2 is absent from the seed list -- reason not evident here.
results = []
for learning_rate in [0.05, 0.1, 0.5, 1., 1.5, 2.]:
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are returned too, but only the summary statistics
        # are recorded in `results`.
        generr, comperr, scounter, entropy, runtime, sim, fit = sf.workflow.simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=0.9,
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=learning_rate,  # the swept parameter
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed_sim=seed,
            seed_fit=seed,
            quiet=True,
        )
        results.append((learning_rate, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line for each completed run.
        print(learning_rate, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (learning rate, seed) run; the `scounter` value is stored under
# the column name 'scounterr', which the plotting cells rely on.
results12 = pd.DataFrame(results, columns=['learning_rate', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Draw each statistic from Experiment 12 against `learning_rate`
# (log-scaled x-axis), one line per seed.
fig, axs = plt.subplots(3, 2)
indexed12 = results12.set_index(['learning_rate', 'seed'])
stat_names = ['generr', 'comperr', 'scounterr', 'entropy', 'runtime']

for ax, stat in zip(axs.flatten(), stat_names):
    per_seed = indexed12[stat].unstack()  # rows: learning rate, columns: seed
    per_seed.plot(ax=ax)
    ax.set(title=stat, xscale='log')
    ax.get_legend().set_visible(False)
fig.tight_layout()

### Experiment 13: Missingness

In [None]:
# Experiment 13 (missingness): sweep the `delta_hyper_p` value used by the fit.
# The simulation keeps `delta_hyper_p=0.7` throughout. Every fit value is run
# once per seed, and the same seed drives both the simulation and the fit.
# NOTE(review): seed 2 is absent from the seed list -- reason not evident here.
results = []
for delta_hyper_p_fit in [0.25, 0.5, 0.75, 0.9, 0.99, 1.]:
    for seed in [0, 1, 3, 4, 5]:
        # `sim` and `fit` are returned too, but only the summary statistics
        # are recorded in `results`.
        generr, comperr, scounter, entropy, runtime, sim, fit = sf.workflow.simulate_fit_and_evaluate(
            s_sim=20,
            n_sim=100,
            g_sim=500,
            n_fit=100,
            g_fit=500,
            sim_kwargs=dict(
                data=dict(
                    alpha_hyper_mean=100.
                ),
                gamma_hyper=0.01,
                delta_hyper_temp=0.01,
                delta_hyper_p=0.7,
                pi_hyper=0.5,
                rho_hyper=2.,
                mu_hyper_mean=1.,
                mu_hyper_scale=0.5,
                m_hyper_r=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5/0.01,
                device='cpu'
            ),
            preclust_kwargs=dict(
                thresh=0.1,
                additional_strains_factor=0.1,
                progress=False,
            ),
            fit_kwargs=dict(
                gamma_hyper=0.01,
                pi_hyper=1.0,
                rho_hyper=0.5,
                mu_hyper_mean=5,
                mu_hyper_scale=5.,
                m_hyper_r=10.,
                delta_hyper_temp=0.1,
                delta_hyper_p=delta_hyper_p_fit,  # the swept parameter
                alpha_hyper_hyper_mean=100.,
                alpha_hyper_hyper_scale=10.,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                device='cpu',
                lag=10,
                lr=2e-0,
                progress=False
            ),
            postclust_kwargs=dict(
                thresh=0.1,
            ),
            seed_sim=seed,
            seed_fit=seed,
            quiet=True,
        )
        results.append((delta_hyper_p_fit, seed, generr, comperr, scounter, entropy, runtime))
        # Tab-separated progress line for each completed run.
        print(delta_hyper_p_fit, seed, generr, comperr, scounter, entropy, runtime, sep='\t')

# One row per (delta_hyper_p_fit, seed) run; the `scounter` value is stored
# under the column name 'scounterr', which the plotting cells rely on.
results13 = pd.DataFrame(results, columns=['delta_hyper_p_fit', 'seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Draw each statistic from Experiment 13 against `delta_hyper_p_fit`
# (log-scaled x-axis), one line per seed.
fig, axs = plt.subplots(3, 2)
panels = axs.flatten()
indexed13 = results13.set_index(['delta_hyper_p_fit', 'seed'])

for stat, ax in zip(['generr', 'comperr', 'scounterr', 'entropy', 'runtime'], panels):
    indexed13[stat].unstack().plot(ax=ax)
    ax.get_legend().set_visible(False)
    ax.set(title=stat, xscale='log')
fig.tight_layout()

### Visualize all

In [None]:
# TODO: Big matrix plot.
#
# Summary grid: one row of panels per statistic, one column per experiment.
# Each experiment entry is
#   (results frame, swept column, x-axis scale, experiment number).
all_results = [
    (results2, 'n_fit', 'log', 2),
    (results5, 'g_fit', 'log', 5),
    # (results6, 's_fit', 'linear', 6),  # left out of this figure
    (results4, 'mu_hyper_mean_sim', 'log', 4),
]

# Each statistic entry is (column name, y-axis scale).
all_stats = [
    ('generr', 'log'),
    ('comperr', 'log'),
    ('scounterr', 'linear'),
    ('entropy', 'linear'),
    ('runtime', 'log'),
]

nres = len(all_results)
nstat = len(all_stats)

fig, axs = plt.subplots(
    nstat,
    nres,
    figsize=(3 * nres, 2 * nstat),
    sharex='col',
    sharey='row',
)

for i, (stat, scale_y) in enumerate(all_stats):
    for j, (frame, indexer, scale_x, _expt) in enumerate(all_results):
        ax = axs[i, j]
        # Rows: swept value, columns: seed -> one line per seed.
        frame.set_index([indexer, 'seed'])[stat].unstack().plot(ax=ax)
        ax.set(ylabel=stat, xlabel=indexer)
        ax.get_legend().set_visible(False)
        ax.set(xscale=scale_x, yscale=scale_y)

fig.tight_layout()

In [None]:
# TODO: Big matrix plot.
#
# Same summary grid as the log-scaled version, but with every axis linear.
# Each experiment entry is
#   (results frame, swept column, x-axis scale, experiment number).
all_results = [
    (results2, 'n_fit', 'linear', 2),
    (results5, 'g_fit', 'linear', 5),
    # (results6, 's_fit', 'linear', 6),  # left out of this figure
    (results4, 'mu_hyper_mean_sim', 'linear', 4),
]

# Each statistic entry is (column name, y-axis scale).
all_stats = [
    ('generr', 'linear'),
    ('comperr', 'linear'),
    ('scounterr', 'linear'),
    ('entropy', 'linear'),
    ('runtime', 'linear'),
]

nres, nstat = len(all_results), len(all_stats)

fig, axs = plt.subplots(nstat, nres, figsize=(3 * nres, 2 * nstat), sharex='col', sharey='row')

for row, (stat, scale_y) in zip(axs, all_stats):
    for ax, (frame, indexer, scale_x, _expt) in zip(row, all_results):
        by_seed = frame.set_index([indexer, 'seed'])[stat].unstack()
        by_seed.plot(ax=ax)
        ax.set_ylabel(stat)
        ax.set_xlabel(indexer)
        ax.get_legend().set_visible(False)
        ax.set_xscale(scale_x)
        ax.set_yscale(scale_y)

fig.tight_layout()
# Attach a fresh legend to the bottom-right panel (the last one drawn).
axs[-1, -1].legend(bbox_to_anchor=(1, 1), title='replicate')

In [None]:
# TODO: Big matrix plot.
#
# Full summary grid: one row of panels per statistic, one column per experiment.
# Each experiment entry is
#   (results frame, swept column, x-axis scale, experiment number);
# the experiment number is used as the panel title.
# NOTE(review): Experiment 13 sweeps delta_hyper_p_fit up to exactly 1.0, which
# lies outside the open interval (0, 1) of a 'logit' axis -- confirm how that
# point should be displayed.
all_results = [
    (results0, 'seed_fit', 'linear', 0),
    (results2, 'n_fit', 'log', 2),
    (results3, 'n_fit', 'log', 3),
    (results4, 'mu_hyper_mean_sim', 'log', 4),
    (results5, 'g_fit', 'log', 5),
    (results6, 's_fit', 'linear', 6),
    (results7, 'gamma_hyper_fit', 'log', 7),
    (results8, 'rho_hyper_fit', 'log', 8),
    (results9, 'pi_hyper_fit', 'log', 9),
    (results10, 'preclust_thresh', 'linear', 10),
    (results11, 'postclust_thresh', 'linear', 11),
    (results12, 'learning_rate', 'log', 12),
    (results13, 'delta_hyper_p_fit', 'logit', 13),
]

# Each statistic entry is (column name, y-axis scale).
all_stats = [
    ('generr', 'log'),
    ('comperr', 'log'),
    ('scounterr', 'symlog'),
    ('entropy', 'linear'),
    ('runtime', 'log')
]

nres = len(all_results)
# Fixed: this was `len*(all_stats)`, which multiplies the builtin `len` by a
# list and raises TypeError before any figure is created.
nstat = len(all_stats)

fig, axs = plt.subplots(nstat, nres, figsize=(2 * nres, 2 * nstat), sharex='col', sharey='row')

for (stat, scale_y), row in zip(all_stats, axs):
    for (results, indexer, scale_x, title), ax in zip(all_results, row):
        # Rows: swept value, columns: seed -> one line per seed.
        results.set_index([indexer, 'seed'])[stat].unstack().plot(ax=ax)
        ax.set_ylabel(stat)
        ax.set_xlabel(indexer)
        ax.legend_.set_visible(False)
        ax.set_xscale(scale_x)
        ax.set_yscale(scale_y)
        ax.set_title(title)

fig.tight_layout()

## Demonstrations

### Demo: Good accuracy with realistic conditions

In [None]:
# Demo: run the simulate -> fit -> evaluate workflow for 20 seeds with larger
# problem sizes (n=200, g=1000) and collect the summary statistics.
# NOTE(review): the simulation runs on 'cpu' while the fit is pinned to 'cuda';
# presumably this cell needs a CUDA device -- confirm.
results = []
for seed in range(20):
    # Settings are rebuilt on every iteration so each call gets fresh dicts.
    sim_settings = {
        'data': {'alpha_hyper_mean': 100.},
        'gamma_hyper': 0.01,
        'delta_hyper_temp': 0.01,
        'delta_hyper_p': 0.9,
        'pi_hyper': 0.5,
        'rho_hyper': 2.,
        'mu_hyper_mean': 10.,
        'mu_hyper_scale': 0.5,
        'm_hyper_r': 10.,
        'alpha_hyper_scale': 0.5,
        'epsilon_hyper_alpha': 1.5,
        'epsilon_hyper_beta': 1.5 / 0.01,
        'device': 'cpu',
    }
    preclust_settings = {
        'thresh': 0.1,
        'additional_strains_factor': 0.1,
        'progress': False,
    }
    fit_settings = {
        'gamma_hyper': 0.01,
        'pi_hyper': 1.0,
        'rho_hyper': 0.5,
        'mu_hyper_mean': 5,
        'mu_hyper_scale': 5.,
        'm_hyper_r': 10.,
        'delta_hyper_temp': 0.1,
        'delta_hyper_p': 0.9,
        'alpha_hyper_hyper_mean': 100.,
        'alpha_hyper_hyper_scale': 10.,
        'alpha_hyper_scale': 0.5,
        'epsilon_hyper_alpha': 1.5,
        'epsilon_hyper_beta': 1.5 / 0.01,
        'device': 'cuda',
        'lag': 10,
        'lr': 2e-0,
        'progress': True,
    }
    postclust_settings = {'thresh': 0.1}

    outcome = sf.workflow.simulate_fit_and_evaluate(
        s_sim=20,
        n_sim=200,
        g_sim=1000,
        n_fit=200,
        g_fit=1000,
        sim_kwargs=sim_settings,
        preclust_kwargs=preclust_settings,
        fit_kwargs=fit_settings,
        postclust_kwargs=postclust_settings,
        seed_sim=seed,
        seed_fit=seed,
        quiet=True,
    )
    # `sim` and `fit` are unpacked too, but only the statistics are recorded.
    generr, comperr, scounter, entropy, runtime, sim, fit = outcome

    record = (seed, generr, comperr, scounter, entropy, runtime)
    results.append(record)
    # Tab-separated progress line for each completed run.
    print(*record, sep='\t')

# One row per seed; the `scounter` value is stored under column 'scounterr'.
results_d0 = pd.DataFrame(results, columns=['seed', 'generr', 'comperr', 'scounterr', 'entropy', 'runtime'])

In [None]:
# Five-bin histogram of each statistic across the demo's seeds.
fig, axs = plt.subplots(3, 2)
stat_names = ['generr', 'comperr', 'scounterr', 'entropy', 'runtime']

for ax, stat in zip(axs.flatten(), stat_names):
    values = results_d0[stat]
    ax.hist(values, bins=5)
    ax.set_title(stat)
fig.tight_layout()