In [None]:
# Load IPython's autoreload extension so edited project modules can be
# reloaded without restarting the kernel.
get_ipython().run_line_magic('load_ext', 'autoreload')

In [None]:
%autoreload

In [None]:
import os
import sys

# Make the local StrainFacts checkout importable as `sfacts`.
# The location can be overridden with the STRAINFACTS_PATH environment variable;
# the default is the original hard-coded path.
STRAINFACTS_PATH = os.environ.get(
    'STRAINFACTS_PATH',
    '/pollard/home/bsmith/Projects/haplo-benchmark/include/StrainFacts',
)
# Guard so re-running this cell does not append a duplicate sys.path entry.
if STRAINFACTS_PATH not in sys.path:
    sys.path.append(STRAINFACTS_PATH)

In [None]:
import sfacts as sf
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from itertools import product
from tqdm import tqdm
import pandas as pd
from lib.plot import construct_ordered_pallete

In [None]:
# Load the simulated ground-truth world from disk. Only s, rho and mu are
# interpolated into the file name; every other parameter is fixed in the path.
# NOTE(review): the original note "10,50 / 50,100" presumably lists other
# parameter combinations that were tried — confirm which variables it refers to.
s, rho, mu = 20, 150, 5

sim_prefix = f"data/sfacts_simulate-model_default-n1000-g10000-s{s}-rho{rho}-pi10-mu{mu}-eps10-alpha100-seed0"
sim_path = sim_prefix + ".world.nc"
sim = sf.data.World.load(sim_path)

In [None]:
# Reference only (not executed): how a comparable world could be simulated
# in-process instead of loaded from disk. NOTE: these sizes/hyperparameters
# (e.g. strain=200) do not appear to match the s20 file loaded above.
# _, sim = sf.workflow.simulate_world(
#     sf.model_zoo.NAMED_STRUCTURES['default'],
#     sizes=dict(sample=1000, position=10000, strain=200),
#     hyperparameters=dict(
#         gamma_hyper=1e-5,
#         rho_hyper=50.,
#         pi_hyper=0.2,
#         mu_hyper_mean=2.0,
#         mu_hyper_scale=1.0,
#         epsilon_hyper_mode=1e-2,
#         epsilon_hyper_spread=1.5,
#         m_hyper_r_mean=5.,
#         m_hyper_r_scale=1.,
#         alpha_hyper_mean=100,
#         alpha_hyper_scale=1.0,
#     ),
#     seed=0,
# )

# Reverse-cumulative distributions of two maxima of the true community matrix.
# Values are negated and binned on [-1, 0] (assumes relative abundances lie in
# [0, 1]); with cumulative=True, the curve at x counts items whose maximum
# abundance is at least -x. Strains and samples have different totals, hence
# the twin y-axes.
bins = np.linspace(-1, 0, num=51)
fig, ax0 = plt.subplots()
ax0.set_xlim(-1, 0)
ax0.set_xlabel('-(maximum relative abundance)')
ax1 = ax0.twinx()

# Reduce over samples: one value per strain (its largest abundance anywhere).
ax0.hist(
    -sim.communities.max('sample'), bins=bins, density=False, cumulative=True,
    histtype='step', label='per strain: max over samples', color='blue',
)
ax0.set_ylabel('strains (cumulative)', color='blue')
ax0.legend(loc='lower left')

# Reduce over strains: one value per sample (abundance of its dominant strain).
ax1.hist(
    -sim.communities.max('strain'), bins=bins, density=False, cumulative=True,
    histtype='step', label='per sample: max over strains', color='orange',
)
ax1.set_ylabel('samples (cumulative)', color='orange')
ax1.legend(loc='lower right')
plt.show()

In [None]:
# Preview the first 100 samples x 500 positions of the simulation, ordering
# the sample columns by the linkage of that preview's communities.
sim_preview = sim.isel(position=slice(0, 500), sample=slice(0, 100))


def _community_sample_linkage(world):
    """Column ordering: linkage of `world`'s communities over samples."""
    return world.communities.linkage('sample')


sf.plot.plot_metagenotype(sim_preview, col_linkage_func=_community_sample_linkage)
sf.plot.plot_community(sim_preview, col_linkage_func=_community_sample_linkage)

In [None]:
# Load the `sfacts1` and `sfinder` fits of the same metagenotype subset
# (n samples x g positions of the simulation above), each fit with s strains
# and the given seed. Paths look like:
#   {sim_prefix}.metagenotype-n{n}-g{g}.fit-{sfacts1,sfinder}-s{s}-seed{seed}.world.nc
# NOTE: this rebinds `s`, which the simulation cell used in its file name.
n, g = 100, 500
seed = 0
s = 30

metagenotype_prefix = f"{sim_prefix}.metagenotype-n{n}-g{g}"
fit0_path = f"{metagenotype_prefix}.fit-sfacts1-s{s}-seed{seed}.world.nc"
fit0 = sf.data.World.load(fit0_path)
fit1_path = f"{metagenotype_prefix}.fit-sfinder-s{s}-seed{seed}.world.nc"
fit1 = sf.data.World.load(fit1_path)

# Ground truth restricted to exactly the positions and samples that fit0 covers.
sim0 = sim.sel(position=fit0.position, sample=fit0.sample)

In [None]:
def _fit_metrics(truth, fit):
    """Score `fit` against `truth` with each sfacts evaluation metric.

    Returns a tuple ordered like the metric columns of the table below. The
    "_flip" entries call the same metric with the arguments swapped.
    """
    return (
        sf.evaluation.metagenotype_error(truth, fit)[0],
        sf.evaluation.community_error(truth, fit)[0],
        sf.evaluation.integrated_community_error(truth, fit)[0],
        sf.evaluation.integrated_community_error(fit, truth)[0],
        sf.evaluation.rank_abundance_error(truth, fit)[0],
        sf.evaluation.weighted_genotype_error(truth, fit),
        sf.evaluation.weighted_genotype_error(fit, truth),
    )


# One row per tool: run identifiers followed by its metrics against sim0.
pd.DataFrame(
    [
        (seed, n, g, tool_name) + _fit_metrics(sim0, fit_world)
        for tool_name, fit_world in [('sfacts', fit0), ('sfinder', fit1)]
    ],
    columns=[
        "seed",
        "n",
        "g",
        "tool",
        "metagenotype_error",
        "community_error",
        "integrated_community_error",
        "integrated_community_error_flip",
        "rank_abundance_error",
        "weighted_genotype_error",
        "weighted_genotype_error_flip",
    ],
)

In [None]:
def _truth_sample_linkage(world):
    """Column ordering shared by all three heatmaps.

    `world` is deliberately ignored: samples are always ordered by the truth
    (`sim0`) so the simulated and fitted communities line up column-for-column.
    """
    return sim0.communities.linkage('sample')


def _truth_genotype_error_rows(world):
    """Row annotations for the truth: genotype error of each fit against it."""
    return xr.Dataset(dict(
        gen0_err=sf.evaluation.genotype_error(world, fit0)[1],
        gen1_err=sf.evaluation.genotype_error(world, fit1)[1],
    ))


def _fit_error_cols(world):
    """Column annotations for a fit: per-column error terms versus `sim0`."""
    return xr.Dataset(dict(
        com_err=sf.evaluation.community_error(sim0, world)[1],
        ice_err=sf.evaluation.integrated_community_error(sim0, world)[1],
        rank_err=sf.evaluation.rank_abundance_error(sim0, world)[1],
        mgn_err=sf.evaluation.metagenotype_error(sim0, world)[1],
    ))


def _fit_error_rows(world):
    """Row annotations for a fit: genotype error versus `sim0`."""
    return xr.Dataset(dict(
        gen_err=sf.evaluation.genotype_error(world, sim0)[1]
    ))


sf.plot.plot_community(
    sim0,
    col_linkage_func=_truth_sample_linkage,
    row_colors_func=_truth_genotype_error_rows,
)
sf.plot.plot_community(
    fit0,
    col_linkage_func=_truth_sample_linkage,
    col_colors_func=_fit_error_cols,
    row_colors_func=_fit_error_rows,
)
sf.plot.plot_community(
    fit1,
    col_linkage_func=_truth_sample_linkage,
    col_colors_func=_fit_error_cols,
    row_colors_func=_fit_error_rows,
)

In [None]:
# Show which fit files were loaded above. (`fit_path` is only defined later,
# by the benchmark loop, so referencing it here fails on a fresh kernel.)
fit0_path, fit1_path

In [None]:
# Benchmark every available fit against the simulated ground truth.
# File layout:
#   data/sfacts_simulate-model_default-n1000-g10000-s{true_s}-rho150-pi10-mu5-eps10-alpha100-seed{sim_seed}
#   .metagenotype-n{n}-g{g}
#   .fit-{fit_type}-s{fit_s}-seed{fit_seed}.world.nc
# Fits whose files are missing are reported and skipped.

true_s = 20

benchmarks = []
for sim_seed in [0]:
    # Use cell-specific names so the `sim` / `sim_prefix` used by earlier cells
    # are not overwritten.
    bench_sim_prefix = f"data/sfacts_simulate-model_default-n1000-g10000-s{true_s}-rho150-pi10-mu5-eps10-alpha100-seed{sim_seed}"
    bench_sim = sf.data.World.load(f"{bench_sim_prefix}.world.nc")

    # Materialize the grid so tqdm can show a total.
    fit_grid = list(product([100], [500], ['sfacts1', 'sfinder'], [14], range(10)))
    for n, g, fit_type, fit_s, fit_seed in tqdm(fit_grid):
        fit_path = f"{bench_sim_prefix}.metagenotype-n{n}-g{g}.fit-{fit_type}-s{fit_s}-seed{fit_seed}.world.nc"
        try:
            fit = sf.data.World.load(fit_path)
        except FileNotFoundError:
            print(f"{fit_path} not found")
            continue

        # Truth restricted to exactly this fit's positions and samples. Do not
        # reassign `bench_sim`, or later fits would select from an
        # already-narrowed world.
        truth = bench_sim.sel(position=fit.position, sample=fit.sample)

        benchmarks.append((
            true_s,
            sim_seed,
            n,
            g,
            fit_seed,
            fit_s,
            fit_type,
            sf.evaluation.metagenotype_error(truth, fit)[0],
            sf.evaluation.community_error(truth, fit)[0],
            sf.evaluation.integrated_community_error(truth, fit)[0],
            sf.evaluation.integrated_community_error(fit, truth)[0],
            sf.evaluation.rank_abundance_error(truth, fit)[0],
            sf.evaluation.weighted_genotype_error(truth, fit),
            sf.evaluation.weighted_genotype_error(fit, truth),
        ))

benchmarks = pd.DataFrame(benchmarks, columns=[
    "true_s",
    "sim_seed",
    "n",
    "g",
    "fit_seed",
    "fit_s",
    "fit_type",
    "metagenotype_error",
    "community_error",
    "integrated_community_error",
    "integrated_community_error_flip",
    "rank_abundance_error",
    "weighted_genotype_error",
    "weighted_genotype_error_flip",
])

In [None]:
# One row per fit loaded by the benchmark loop (fits with missing files were skipped).
benchmarks

In [None]:
# One swarm-plot panel per benchmark metric, comparing fitting tools.
metrics = [
    "metagenotype_error",
    "community_error",
    "integrated_community_error",
    "integrated_community_error_flip",
    "rank_abundance_error",
    "weighted_genotype_error",
    # "weighted_genotype_error_flip",  # excluded from the plots (reason not recorded — TODO)
]
ncol = 3
nrow = int(np.ceil(len(metrics) / ncol))

# squeeze=False always returns a 2-D array of Axes, so .ravel() also works for
# a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
fig, axs = plt.subplots(nrow, ncol, figsize=(5 * ncol, 5 * nrow), squeeze=False)
flat_axes = axs.ravel()

for met, ax in zip(metrics, flat_axes):
    sns.swarmplot(x='fit_type', y=met, data=benchmarks, ax=ax)
    ax.set_title(met)

# Hide grid cells left empty when len(metrics) is not a multiple of ncol.
for ax in flat_axes[len(metrics):]:
    ax.set_visible(False)

fig.tight_layout()