In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
from functools import partial

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pyro
import scipy.spatial.distance
import seaborn as sns
import torch
import xarray as xr

import sfacts as sf

# Set matplotlib's default figure resolution (dots per inch) to 70 for every
# figure created after this cell runs.
mpl.rcParams['figure.dpi'] = 70

def min_max_normalize(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    x
        Any object exposing ``.min()`` / ``.max()`` and supporting elementwise
        subtraction and division (e.g. a NumPy array).

    Returns
    -------
    ``(x - x.min()) / (x.max() - x.min())``.  When the range is a scalar equal
    to zero (every element identical) the shifted values -- all zeros -- are
    returned instead of the NaNs a 0/0 division would produce.
    """
    lowest = x.min()
    shifted = x - lowest
    span = x.max() - lowest
    # Only short-circuit when the range is a single scalar; containers whose
    # min/max are themselves vectors keep the plain elementwise division.
    if np.ndim(span) == 0 and span == 0:
        return shifted
    return shifted / span

In [None]:
import pandas as pd

In [None]:
import scipy as sp

In [None]:
# Text of the torch.jit tracer warning to silence.  `filterwarnings` treats
# this as a regular expression matched against the start of the message.
_TRACER_CONSTANT_WARNING = (
    "torch.tensor results are registered as constants in the trace. "
    "You can safely ignore this warning if you use this function to create "
    "tensors out of constant variables that would be the same every time you "
    "call this function. "
    "In any other case, this might cause the trace to be incorrect."
)

# FIXME: narrowing the filter by `module=` (e.g. "trace_elbo") or `lineno=`
# was attempted but the correct module regex is still unknown.
warnings.filterwarnings(
    action="ignore",
    message=_TRACER_CONSTANT_WARNING,
    category=torch.jit.TracerWarning,
)

## Plot Real Data

### 100022

In [None]:
# Sanity check on sfacts/data.py: load the 100022 pileup, filter positions and
# samples, and check that the estimated genotypes pass `validate_constraints`.
obs = (
    sf.data.Metagenotypes.load('data/ucfmt.sp-100022.gtpro-pileup.nc')
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
    .to_world()
)

obs.metagenotypes.to_estimated_genotypes().validate_constraints()

print(obs.sizes)

# Plot at most the first 500 positions.  The number of positions that survive
# filtering depends on the data, so clamp the index range rather than assuming
# at least 500 remain (an out-of-range index would presumably raise).
n_plot_positions = min(500, len(obs.position.values))
sf.plot.plot_metagenotype(
    obs.isel(position=range(n_plot_positions)),
)

### 102506

In [None]:
# Sanity check on sfacts/data.py: load the 102506 pileup, filter positions and
# samples, and check that the estimated genotypes pass `validate_constraints`.
# NOTE: this rebinds `obs`, replacing the 100022 data loaded earlier.
obs = (
    sf.data.Metagenotypes.load('data/ucfmt.sp-102506.gtpro-pileup.nc')
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
    .to_world()
)

obs.metagenotypes.to_estimated_genotypes().validate_constraints()

print(obs.sizes)

# Plot at most the first 1000 positions.  The number of positions that survive
# filtering depends on the data, so clamp the index range rather than assuming
# at least 1000 remain (an out-of-range index would presumably raise).
n_plot_positions = min(1000, len(obs.position.values))
sf.plot.plot_metagenotype(
    obs.isel(position=range(n_plot_positions)),
)

## Simulate Real-looking Data

In [None]:
# Sizes / labels of each dimension of the simulated world.
sim_coords = {
    'sample': 100,
    'position': 500,
    'strain': 20,
    'allele': ['alt', 'ref'],
}

# Values passed as `data=` to the model.  Alternatives tried previously:
#   alpha=100 * np.ones(100),   epsilon=0.05 * np.ones(100)
#   alpha=10000 * np.ones(100), epsilon=0.000001 * np.ones(100)
sim_data = {'m_hyper_r_mean': 4.}

# Hyperparameters used for simulation (`m_hyper_r_mu=5` was tried and dropped).
sim_hyperparameters = {
    'gamma_hyper': 0.001,
    'delta_hyper_r': 0.85,
    'delta_hyper_temp': 0.001,
    'rho_hyper': 3.,
    'pi_hyper': 0.2,
    'alpha_hyper_hyper_mean': 200.0,
    'alpha_hyper_hyper_scale': 1.0,
    'alpha_hyper_scale': 1.0,
    'epsilon_hyper_alpha': 1.5,
    'epsilon_hyper_beta': 1.5 / 0.01,
    'mu_hyper_mean': 10.0,
    'mu_hyper_scale': 1.5,
    'm_hyper_r_scale': 1,
}

sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=sim_coords,
    data=sim_data,
    hyperparameters=sim_hyperparameters,
)
# Debugging aid: print(sim_model.data, sim_model.hyperparameters)

# Draw one simulated world; the fixed seed keeps the draw repeatable.
sim = sim_model.simulate_world(seed=2)

In [None]:
def _sim_parameter_annotations(w):
    """Build the column-annotation dataset for the simulated metagenotype plot.

    Contains sqrt(mu), sqrt(alpha), cbrt(m_hyper_r) and the maximum of
    `communities` over the 'strain' dimension (named 'max_frac').
    """
    return xr.Dataset({
        'mu': np.sqrt(w.data.mu),
        'alpha': np.sqrt(w.data.alpha),
        'm_hyper_r': np.cbrt(w.data.m_hyper_r),
        'max_frac': w.communities.max('strain').rename('max_frac'),
    })


# Optional keyword arguments explored previously and currently disabled:
# row_linkage_func / col_linkage_func (metagenotype linkage over 'position' /
# 'strain'), metric='euclidean', row_col_annotation_cmap=mpl.cm.rainbow.
g = sf.plot.plot_metagenotype(
    sim,
    col_colors_func=_sim_parameter_annotations,
)

In [None]:
def _sim_max_strain_annotation(w):
    """Column annotation holding, per row of `communities`, the label of the
    column with the largest value (`idxmax` along axis 1)."""
    return xr.Dataset({
        'max_strain': w.communities.to_pandas().idxmax(axis=1),
    })


# Same plot as before, but annotated by `max_strain` using the categorical
# `tab20` colormap.  (Custom row/column linkage and metric='euclidean' were
# explored previously and are currently disabled.)
g = sf.plot.plot_metagenotype(
    sim,
    col_colors_func=_sim_max_strain_annotation,
    row_col_annotation_cmap=mpl.cm.tab20,
)

In [None]:
def cosine_sample_distances(w):
    """Square matrix of pairwise cosine distances between samples.

    The metagenotype data of `w` is converted to a table with one row per
    entry of the 'sample' level (all remaining index levels become columns);
    `pdist(..., 'cosine')` is computed between rows and expanded to a square
    DataFrame indexed and labelled by `w.sample`.
    """
    per_sample_profiles = (
        w.metagenotypes.data
        .to_dataframe()
        .squeeze()
        .unstack('sample')
        .T
    )
    # Use the explicitly imported submodule: attribute access through a bare
    # `import scipy` only works when scipy.spatial happens to be loaded already.
    condensed = scipy.spatial.distance.pdist(per_sample_profiles, 'cosine')
    return pd.DataFrame(
        scipy.spatial.distance.squareform(condensed),
        index=w.sample,
        columns=w.sample,
    )


fig, ax = plt.subplots(figsize=(5, 5))

# Marker size: squared maximum of `communities` over 'strain', scaled by 75.
# Marker color: label of the largest `communities` entry per row.
# Alternatives tried previously: sizes from sqrt(mu) * 10; colors from the
# max over 'strain' or from sqrt(alpha).
sf.plot.ordination_plot(
    sim,
    dmat_func=cosine_sample_distances,
    vmin=0,
    sizes_func=lambda w: w.communities.max('strain')**2 * 75,
    colors_func=lambda w: w.communities.to_pandas().idxmax(1),
    cmap=mpl.cm.tab20,
    ax=ax,
)
None

In [None]:
# Draw the simulated world twice: first with the library's fuzzed-genotype
# plot, then with its missing-data plot.  Only the second call's return value
# is the cell's displayed result.
sf.plot.plot_fuzzed_genotype(sim)
sf.plot.plot_missing(sim)

In [None]:
def _sqrt_mu_alpha(w):
    """Row annotations: elementwise square root of the `mu` and `alpha`
    variables of `w.data`."""
    return np.sqrt(w.data[['mu', 'alpha']])


sf.plot.plot_community(sim, row_colors_func=_sqrt_mu_alpha)

## Fit Simulated Data

In [None]:
# 20 equal-width bins spanning [0.5, 1.0].
bins = np.linspace(0.5, 1.0, num=21)

# Restrict the simulated world to a single sample.
sample = [7]
d = sim.sel(sample=sample)

# Histogram of the dominant-allele fraction for the selected sample.
dominant_fractions = d.metagenotypes.dominant_allele_fraction().values.T
plt.hist(dominant_fractions, bins=bins)

# One dashed vertical line at (1 - value) for every entry of `communities`.
for strain_fraction in d.communities.values.squeeze():
    plt.axvline(1 - strain_fraction, color='k', lw=1, linestyle='--')

plt.xlim(0.5, 1.0)

In [None]:
# Data to fit: the full simulated world.
d = sim

# Re-use the data's sample/position/allele labels; allow 30 strains in the fit.
fit_coords = {
    'sample': d.sample.values,
    'position': d.position.values,
    'allele': d.allele.values,
    'strain': range(30),
}

fit_hyperparameters = {
    'gamma_hyper': 0.01,
    'delta_hyper_r': 0.8,
    'delta_hyper_temp': 0.1,
    'rho_hyper': 0.01,
    'pi_hyper': 0.5,
    'alpha_hyper_hyper_mean': 200.0,
    'alpha_hyper_hyper_scale': 1.0,
    'alpha_hyper_scale': 0.5,
    'epsilon_hyper_alpha': 1.5,
    'epsilon_hyper_beta': 1.5 / 0.01,
}

model_fit = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=fit_coords,
    hyperparameters=fit_hyperparameters,
)

# Condition the model on the counts/totals derived from the metagenotypes.
conditioned_model = model_fit.condition(
    **d.metagenotypes.to_counts_and_totals()
)

# Three-stage fit; stage 2 overrides gamma_hyper with 1.0.
est1, history = sf.workflow.three_stage_fitting(
    conditioned_model,
    stage2_hyperparameters={'gamma_hyper': 1.0},
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
# Compare the fit against the simulation it was fitted to: weighted genotype
# error first, then community error.
genotype_error = sf.evaluation.weighted_genotype_error(sim, est1)
community_error = sf.evaluation.community_error(sim, est1)
print(genotype_error, community_error)

In [None]:
def _sim_vs_est_community_annotations(w):
    """Column annotations for the concatenated simulated + estimated world.

    - abundance: mean of `communities` over 'sample'
    - entropy:   `w.genotypes.entropy`
    - which_fit: 0 for entries that came from 'sim', 1 for those from 'est'
      (the same coding the fuzzed-genotype and missingness comparison plots
      use, so colors mean the same thing in all three figures)
    """
    return xr.Dataset(dict(
        abundance=w.communities.mean("sample"),
        entropy=w.genotypes.entropy,
        which_fit=(
            w.data['_concat_from']
            .to_series()
            .map({'sim': 0, 'est': 1})
            .to_xarray()
        ),
    ))


# Stack the simulated and estimated worlds along 'strain' so both sets of
# strains appear side by side in one community plot.
sim_and_est = sf.data.World.concat(
    {
        'sim': sim,
        'est': est1,
    },
    dim='strain',
)

sf.plot.plot_community(
    sim_and_est,
    col_colors_func=_sim_vs_est_community_annotations,
    norm=None,
)

In [None]:
def _sim_vs_est_genotype_annotations(w):
    """Column annotations: mean of `communities` over 'sample', genotype
    entropy, and a 0/1 flag (0 = from 'sim', 1 = from 'est')."""
    origin_flag = (
        w.data['_concat_from']
        .to_series()
        .map({'sim': 0, 'est': 1})
        .to_xarray()
    )
    return xr.Dataset({
        'abundance': w.communities.mean("sample"),
        'entropy': w.genotypes.entropy,
        'which_fit': origin_flag,
    })


# Simulated and estimated worlds concatenated along 'strain'.
sim_and_est = sf.data.World.concat({'sim': sim, 'est': est1}, dim='strain')

sf.plot.plot_fuzzed_genotype(
    sim_and_est,
    col_colors_func=_sim_vs_est_genotype_annotations,
)

In [None]:
def _sim_vs_est_missing_annotations(w):
    """Column annotations: mean of `communities` over 'sample', genotype
    entropy, and a 0/1 flag (0 = from 'sim', 1 = from 'est')."""
    origin_flag = (
        w.data['_concat_from']
        .to_series()
        .map({'sim': 0, 'est': 1})
        .to_xarray()
    )
    return xr.Dataset({
        'abundance': w.communities.mean("sample"),
        'entropy': w.genotypes.entropy,
        'which_fit': origin_flag,
    })


# Simulated and estimated worlds concatenated along 'strain'.
sim_and_est = sf.data.World.concat({'sim': sim, 'est': est1}, dim='strain')

sf.plot.plot_missing(
    sim_and_est,
    col_colors_func=_sim_vs_est_missing_annotations,
)

In [None]:
# Simulated `mu` (x) against estimated `mu` (y).  Points are colored by the
# simulated counts summed over 'allele' and averaged over 'position'.
points = plt.scatter(
    sim.data.mu,
    est1.data.mu,
    c=sim.metagenotypes.sum('allele').mean('position'),
)
plt.colorbar(points, label='mean depth (simulated)')
plt.xlabel('simulated mu')
plt.ylabel('estimated mu')
None

In [None]:
# Simulated `m_hyper_r` (x) against estimated `m_hyper_r` (y).
plt.scatter(sim.data.m_hyper_r, est1.data.m_hyper_r)
plt.xlabel('simulated m_hyper_r')
plt.ylabel('estimated m_hyper_r')
None

In [None]:
# Simulated `epsilon` (x) against estimated `epsilon` (y), colored by the
# estimated `mu`; points are semi-transparent so overlaps stay visible.
points = plt.scatter(
    sim.data.epsilon,
    est1.data.epsilon,
    c=est1.data.mu,
    alpha=0.7,
)
plt.colorbar(points, label='estimated mu')
plt.xlabel('simulated epsilon')
plt.ylabel('estimated epsilon')
None

## Fit Real Data

In [None]:
def _sqrt_mean_depth(w):
    """Square root of the metagenotype counts summed over 'allele' and averaged
    over 'position', named 'mean_depth'."""
    mean_depth = w.metagenotypes.sum('allele').mean('position')
    return np.sqrt(mean_depth).rename('mean_depth')


# Sanity check on sfacts/data.py.
# `obs` was last bound to the 102506 data, so the 100022 pileup is loaded and
# filtered again here before fitting.
obs = (
    sf.data.Metagenotypes
    .load('data/ucfmt.sp-100022.gtpro-pileup.nc')
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
    .to_world()
)

obs.metagenotypes.to_estimated_genotypes().validate_constraints()

print(obs.sizes)

# All positions are plotted; subset with `.isel(position=range(1000))` on `obs`
# if a quicker plot is needed.
sf.plot.plot_metagenotype(obs, col_colors_func=_sqrt_mean_depth)

In [None]:
def _sqrt_mean_depth(w):
    """Square root of the metagenotype counts summed over 'allele' and averaged
    over 'position', named 'mean_depth'."""
    mean_depth = w.metagenotypes.sum('allele').mean('position')
    return np.sqrt(mean_depth).rename('mean_depth')


# Depth plot of the observed data, annotated with the same sqrt(mean depth)
# column colors as the metagenotype plot.  All positions are plotted; subset
# with `.isel(position=range(1000))` on `obs` if a quicker plot is needed.
sf.plot.plot_depth(obs, col_colors_func=_sqrt_mean_depth)

In [None]:
# Data to fit: the full observed world (append `.isel(position=range(500))`
# to fit a subset of positions instead).
d = obs

# Re-use the data's sample/position/allele labels; allow 30 strains in the fit.
real_fit_coords = {
    'sample': d.sample.values,
    'position': d.position.values,
    'allele': d.allele.values,
    'strain': range(30),
}

# Same hyperparameter values as the fit to simulated data.
real_fit_hyperparameters = {
    'gamma_hyper': 0.01,
    'delta_hyper_r': 0.8,
    'delta_hyper_temp': 0.1,
    'rho_hyper': 0.01,
    'pi_hyper': 0.5,
    'alpha_hyper_hyper_mean': 200.0,
    'alpha_hyper_hyper_scale': 1.0,
    'alpha_hyper_scale': 0.5,
    'epsilon_hyper_alpha': 1.5,
    'epsilon_hyper_beta': 1.5 / 0.01,
}

model_fit2 = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=real_fit_coords,
    hyperparameters=real_fit_hyperparameters,
)

# Condition the model on the counts/totals derived from the metagenotypes.
conditioned_model2 = model_fit2.condition(
    **d.metagenotypes.to_counts_and_totals()
)

# Three-stage fit; stage 2 overrides gamma_hyper with 1.0.
# NOTE: `history` is rebound here, replacing the simulated-data fit's history.
est2, history = sf.workflow.three_stage_fitting(
    conditioned_model2,
    stage2_hyperparameters={'gamma_hyper': 1.0},
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
# Counts summed over 'allele' and averaged over 'position' (x) against the
# fitted `mu` (y) for the real-data fit.
plt.scatter(est2.metagenotypes.sum('allele').mean('position'), est2.data.mu)
plt.xlabel('mean depth')
plt.ylabel('estimated mu')
None

In [None]:
def _est2_community_col_annotations(w):
    """Column annotations: mean of `communities` over 'sample', genotype
    entropy, and one minus the mean of `missingness` over 'position'."""
    return xr.Dataset({
        'abundance': w.communities.mean("sample"),
        'entropy': w.genotypes.entropy,
        'missing': 1 - w.missingness.mean("position"),
    })


def _est2_community_row_annotations(w):
    """Row annotations: sqrt(mu), sqrt(alpha), and the maximum of
    `communities` over 'strain' (named 'max_frac')."""
    return xr.Dataset({
        'mu': np.sqrt(w.data.mu),
        'alpha': np.sqrt(w.data.alpha),
        'max_frac': w.communities.max('strain').rename('max_frac'),
    })


d = est2

sf.plot.plot_community(
    d,
    col_colors_func=_est2_community_col_annotations,
    row_colors_func=_est2_community_row_annotations,
    norm=None,
)

In [None]:
# 20 equal-width bins spanning [0.5, 1.0].
bins = np.linspace(0.5, 1.0, num=21)

# Sample to inspect (alternative: ['SS01105']).
sample = ['DS0097_035']
d = est2.sel(sample=sample)

# Histogram of the dominant-allele fraction for the selected sample.
dominant_fractions = d.metagenotypes.dominant_allele_fraction().values.T
plt.hist(dominant_fractions, bins=bins)

# One dashed vertical line at (1 - value) for every entry of `communities`.
for strain_fraction in d.communities.values.squeeze():
    plt.axvline(1 - strain_fraction, color='k', lw=1, linestyle='--')

plt.xlim(0.5, 1.0)

In [None]:
def _est2_genotype_annotations(w):
    """Column annotations: sqrt of the mean of `communities` over 'sample',
    genotype entropy, and one minus the mean of `missingness` over
    'position'."""
    return xr.Dataset({
        'abundance': np.sqrt(w.communities.mean("sample")),
        'entropy': w.genotypes.entropy,
        'missing': 1 - w.missingness.mean("position"),
    })


# A cosine-linkage column ordering (`w.fuzzed_genotypes.cosine_linkage()`) was
# tried via `col_linkage_func` and is currently disabled.
sf.plot.plot_fuzzed_genotype(
    est2,
    col_colors_func=_est2_genotype_annotations,
)

In [None]:
def _est2_missing_annotations(w):
    """Column annotations: sqrt of the mean of `communities` over 'sample',
    genotype entropy, and one minus the mean of `missingness` over
    'position'."""
    return xr.Dataset({
        'abundance': np.sqrt(w.communities.mean("sample")),
        'entropy': w.genotypes.entropy,
        'missing': 1 - w.missingness.mean("position"),
    })


sf.plot.plot_missing(
    est2,
    col_colors_func=_est2_missing_annotations,
)

In [None]:
def _est2_community_linkage(w):
    """Column linkage computed from `communities` along 'sample'."""
    return w.communities.linkage(dim='sample')


def _est2_parameter_annotations(w):
    """Column annotations: sqrt(mu), sqrt(alpha), the maximum of `communities`
    over 'strain' (named 'max_frac'), and cbrt(m_hyper_r)."""
    return xr.Dataset({
        'mu': np.sqrt(w.data.mu),
        'alpha': np.sqrt(w.data.alpha),
        'max_frac': w.communities.max('strain').rename('max_frac'),
        'm_hyper_r': np.cbrt(w.data.m_hyper_r),
    })


# Row linkage from the metagenotypes over 'position', row_linkage=None and
# metric='euclidean' were explored previously and are currently disabled.
sf.plot.plot_metagenotype(
    est2,
    col_linkage_func=_est2_community_linkage,
    col_colors_func=_est2_parameter_annotations,
)

## Benchmarking

In [None]:
results = []

for sim_seed in range(5):
    for fit_seed in range(5):
        res = sf.workflow.simulation_benchmark(
            nsample=100,
            nposition=500,
            sim_model=sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
            sim_nstrain=20,
            sim_data=dict(
                m_hyper_r=4.,
            ),
            sim_hyperparameters=dict(
                gamma_hyper=0.001,
                delta_hyper_r=0.85,
                delta_hyper_temp=0.001,
                rho_hyper=3.,
                pi_hyper=0.2,
                alpha_hyper_hyper_mean=200.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=1.0,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                mu_hyper_mean=10.0,
                mu_hyper_scale=1.5,
            ),
            fit_hyperparameters=dict(
                gamma_hyper=0.01,
                delta_hyper_r=0.9,
                delta_hyper_temp=0.1,
                rho_hyper=0.01,
                pi_hyper=1.,
                alpha_hyper_hyper_mean=100.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
            ),
            fit_nstrain=20,
            fit_seed=fit_seed,
            sim_seed=sim_seed,
            lagA=1,
            lagB=10,
            quiet=True,
        )
        res = (sim_seed, fit_seed, *res)
        print(res)
        results.append(res)
        
results0 = pd.DataFrame(results, columns=['sim_seed', 'fit_seed', 'genotype_error', 'community_error', 'runtime'])

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(8, 2))

for stat, ax in zip(['genotype_error', 'community_error', 'runtime'], axs.flatten()):
    results0.set_index(['fit_seed', 'sim_seed'])[stat].unstack().plot(ax=ax)
    ax.set_title(stat)
    ax.legend_.set_visible(False)
fig.tight_layout()

In [None]:
results = []

for sim_seed in range(5):
    for nsample in [10, 25, 50, 100, 200]:
        res = sf.workflow.simulation_benchmark(
            nsample=nsample,
            nposition=500,
            sim_model=sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
            sim_nstrain=20,
            sim_data=dict(
                m_hyper_r=4.,
            ),
            sim_hyperparameters=dict(
                gamma_hyper=0.001,
                delta_hyper_r=0.85,
                delta_hyper_temp=0.001,
                rho_hyper=3.,
                pi_hyper=0.2,
                alpha_hyper_hyper_mean=200.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=1.0,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                mu_hyper_mean=10.0,
                mu_hyper_scale=1.5,
            ),
            fit_hyperparameters=dict(
                gamma_hyper=0.01,
                delta_hyper_r=0.9,
                delta_hyper_temp=0.1,
                rho_hyper=0.01,
                pi_hyper=1.,
                alpha_hyper_hyper_mean=100.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
            ),
            fit_nstrain=20,
            fit_seed=1,
            sim_seed=sim_seed,
            lagA=1,
            lagB=10,
            quiet=True,
        )
        res = (sim_seed, nsample, *res)
        print(res)
        results.append(res)
        
results1 = pd.DataFrame(results, columns=['sim_seed', 'nsample', 'genotype_error', 'community_error', 'runtime'])

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(8, 2))

for stat, ax in zip(['genotype_error', 'community_error', 'runtime'], axs.flatten()):
    results1.set_index(['nsample', 'sim_seed'])[stat].unstack().plot(ax=ax)
    ax.set_title(stat)
    ax.legend_.set_visible(False)
    ax.set_yscale('log')
    ax.set_xscale('log')
fig.tight_layout()

In [None]:
results = []

for sim_seed in range(5):
    for nposition in [20, 50, 100, 200, 500, 1000, 2000]:
        res = sf.workflow.simulation_benchmark(
            nsample=100,
            nposition=nposition,
            sim_model=sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
            sim_nstrain=20,
            sim_data=dict(
                m_hyper_r=4.,
            ),
            sim_hyperparameters=dict(
                gamma_hyper=0.001,
                delta_hyper_r=0.85,
                delta_hyper_temp=0.001,
                rho_hyper=3.,
                pi_hyper=0.2,
                alpha_hyper_hyper_mean=200.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=1.0,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                mu_hyper_mean=10.0,
                mu_hyper_scale=1.5,
            ),
            fit_hyperparameters=dict(
                gamma_hyper=0.01,
                delta_hyper_r=0.9,
                delta_hyper_temp=0.1,
                rho_hyper=0.01,
                pi_hyper=1.,
                alpha_hyper_hyper_mean=100.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
            ),
            fit_nstrain=20,
            fit_seed=1,
            sim_seed=sim_seed,
            lagA=1,
            lagB=10,
            quiet=True,
        )
        res = (sim_seed, nposition, *res)
        print(res)
        results.append(res)
        
results2 = pd.DataFrame(results, columns=['sim_seed', 'nposition', 'genotype_error', 'community_error', 'runtime'])

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(8, 2))

for stat, ax in zip(['genotype_error', 'community_error', 'runtime'], axs.flatten()):
    results2.set_index(['nposition', 'sim_seed'])[stat].unstack().plot(ax=ax)
    ax.set_title(stat)
    ax.legend_.set_visible(False)
    ax.set_yscale('log')
    ax.set_xscale('log')
fig.tight_layout()

In [None]:
results = []

for sim_seed in range(5):
    for fit_gamma_hyper in [1.1, 1.0, 0.2, 0.1, 0.05, 0.01]:
        res = sf.workflow.simulation_benchmark(
            nsample=100,
            nposition=500,
            sim_model=sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
            sim_nstrain=20,
            sim_data=dict(
                m_hyper_r=4.,
            ),
            sim_hyperparameters=dict(
                gamma_hyper=0.001,
                delta_hyper_r=0.85,
                delta_hyper_temp=0.001,
                rho_hyper=3.,
                pi_hyper=0.2,
                alpha_hyper_hyper_mean=200.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=1.0,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                mu_hyper_mean=10.0,
                mu_hyper_scale=1.5,
            ),
            fit_hyperparameters=dict(
                gamma_hyper=fit_gamma_hyper,
                delta_hyper_r=0.9,
                delta_hyper_temp=0.1,
                rho_hyper=0.01,
                pi_hyper=1.,
                alpha_hyper_hyper_mean=100.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
            ),
            fit_nstrain=20,
            fit_seed=1,
            sim_seed=sim_seed,
            lagA=1,
            lagB=10,
            quiet=True,
        )
        res = (sim_seed, fit_gamma_hyper, *res)
        print(res)
        results.append(res)
        
results3 = pd.DataFrame(results, columns=['sim_seed', 'fit_gamma_hyper', 'genotype_error', 'community_error', 'runtime'])

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(8, 2))

for stat, ax in zip(['genotype_error', 'community_error', 'runtime'], axs.flatten()):
    results3.set_index(['fit_gamma_hyper', 'sim_seed'])[stat].unstack().plot(ax=ax)
    ax.set_title(stat)
    ax.legend_.set_visible(False)
    ax.set_yscale('log')
    ax.set_xscale('log')
fig.tight_layout()

In [None]:
results = []

for sim_seed in range(5):
    for fit_rho_hyper in [1.0, 0.5, 0.1, 0.01, 0.001]:
        res = sf.workflow.simulation_benchmark(
            nsample=100,
            nposition=500,
            sim_model=sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
            sim_nstrain=20,
            sim_data=dict(
                m_hyper_r=4.,
            ),
            sim_hyperparameters=dict(
                gamma_hyper=0.001,
                delta_hyper_r=0.85,
                delta_hyper_temp=0.001,
                rho_hyper=3.,
                pi_hyper=0.2,
                alpha_hyper_hyper_mean=200.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=1.0,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
                mu_hyper_mean=10.0,
                mu_hyper_scale=1.5,
            ),
            fit_hyperparameters=dict(
                gamma_hyper=0.02,
                delta_hyper_r=0.9,
                delta_hyper_temp=0.1,
                rho_hyper=fit_rho_hyper,
                pi_hyper=1.,
                alpha_hyper_hyper_mean=100.0,
                alpha_hyper_hyper_scale=1.0,
                alpha_hyper_scale=0.5,
                epsilon_hyper_alpha=1.5,
                epsilon_hyper_beta=1.5 / 0.01,
            ),
            fit_nstrain=20,
            fit_seed=1,
            sim_seed=sim_seed,
            lagA=1,
            lagB=10,
            quiet=True,
        )
        res = (sim_seed, fit_rho_hyper, *res)
        print(res)
        results.append(res)
        
results4 = pd.DataFrame(results, columns=['sim_seed', 'fit_rho_hyper', 'genotype_error', 'community_error', 'runtime'])

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(8, 2))

for stat, ax in zip(['genotype_error', 'community_error', 'runtime'], axs.flatten()):
    results4.set_index(['fit_rho_hyper', 'sim_seed'])[stat].unstack().plot(ax=ax)
    ax.set_title(stat)
    ax.legend_.set_visible(False)
    ax.set_yscale('log')
    ax.set_xscale('log')
fig.tight_layout()