In [None]:
# Automatically reload edited modules (e.g. sfacts) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf
import pyro
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
from functools import partial
import xarray as xr
import warnings
import torch

# Lower inline figure resolution (matplotlib's default figure.dpi is 100) to keep figures compact.
mpl.rcParams['figure.dpi'] = 70

def min_max_normalize(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Accepts anything that supports ``.min()``, ``.max()`` and elementwise
    arithmetic (numpy arrays, pandas/xarray objects, ...). If all elements are
    equal the denominator is zero, so the result is not finite.
    """
    lo = x.min()
    span = x.max() - lo
    return (x - lo) / span

In [None]:
import pandas as pd

In [None]:
import scipy as sp

In [None]:
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=torch.jit.TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

### Plot Beta Distribution

In [None]:
# Older SciPy releases do not load submodules on `import scipy as sp`, so load stats explicitly.
import scipy.stats

# Symmetric Beta(gamma*, gamma*) densities for several values of gamma_hyper.
# The grid excludes the endpoints 0 and 1, where the density is 0 (gamma* > 1)
# or infinite (gamma* < 1) and cannot be drawn on a log-scaled y axis.
xx = np.linspace(0, 1, num=1002)[1:-1]

for gamma_hyper in [1.1, 1.0, 0.5, 0.1]:
    plt.plot(xx, sp.stats.beta(gamma_hyper, gamma_hyper).pdf(xx), label=gamma_hyper, lw=3, alpha=0.7)
plt.legend(title=r'$\gamma^*$')
plt.yscale('log')
plt.ylabel('PDF')
plt.xlabel('x')
plt.title(r'Beta($\gamma^*$, $\gamma^*$) density');

### Show Real Data Metagenotype Heatmaps

In [None]:
# Sanity check on sfacts/data.py
obs = (
    sf.data.Metagenotypes.load('data/ucfmt.sp-100022.gtpro-pileup.nc')
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
    .to_world()
)

obs.metagenotypes.to_estimated_genotypes().validate_constraints()

d = obs
d = d.sel(sample=d.metagenotypes.mean_depth("sample") > 10)
d = d.sel(sample=d.metagenotypes.entropy("sample") < 100)
d = d.sel(position=d.metagenotypes.total_counts().astype(bool).var("sample") < 0.05)

print(obs.sizes)
sf.plot.plot_metagenotype(
    (
        d
#         .isel(position=range(1000))
    ),
    col_colors_func=None,
#     col_colors_func=(
#         lambda w: (
#             w
#             .metagenotypes
#             .sum('allele')
#             .mean('position')
#             .pipe(np.sqrt)
#             .rename('mean_depth')
#         )
#     ),
)

In [None]:
# Sanity check on sfacts/data.py
obs = (
    sf.data.Metagenotypes.load('data/ucfmt.sp-100022.gtpro-pileup.nc')
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
    .to_world()
)

obs.metagenotypes.to_estimated_genotypes().validate_constraints()

d = obs
# d = d.sel(sample=d.metagenotypes.mean_depth("sample") > 10)
# d = d.sel(sample=d.metagenotypes.entropy("sample") < 100)
# d = d.sel(position=d.metagenotypes.total_counts().astype(bool).var("sample") < 0.05)

print(obs.sizes)
sf.plot.plot_metagenotype(
    (
        d
#         .isel(position=range(1000))
    ),
    col_colors_func=None,
#     col_colors_func=(
#         lambda w: (
#             w
#             .metagenotypes
#             .sum('allele')
#             .mean('position')
#             .pipe(np.sqrt)
#             .rename('mean_depth')
#         )
#     ),
)

In [None]:
d = obs  # unfiltered real data

print(obs.sizes)
# Depth view (sf.plot.plot_depth) of the same unfiltered data.
sf.plot.plot_depth(d, col_colors_func=None)

### Constructed Examples

#### Demo Super-Simple Model

In [None]:
# "Super-simple" simulation: 100 samples, 500 positions, 10 strains, using the
# simple_metagenotype model with only two hyperparameters set.
n, g, s = 100, 500, 10

sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.simple_metagenotype,
    coords=dict(
        sample=range(n),
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    hyperparameters=dict(
        # NOTE(review): presumably the Beta(gamma*, gamma*) parameter plotted above — confirm.
        gamma_hyper=0.01,
        pi_hyper=0.1,
    )
)
# print(sim_model.data, sim_model.hyperparameters)

# Fixed seed so the simulated world is the same on every run.
sim0 = sim_model.simulate_world(seed=2)

sim0

In [None]:
# Histogram of all values of sim0.data.m. Reads sim0 directly so the cell does
# not depend on `_world`, which is first assigned in a later cell.
plt.hist(sim0.data.m.values.flatten())
plt.xlabel('m');

In [None]:
_world = sim0

# Metagenotype heatmap of the simulated samples.
sf.plot.plot_metagenotype(_world)

In [None]:
_world = sim0
# Plot the simulated communities (strain composition per sample).
sf.plot.plot_community(_world)

In [None]:
_world = sim0
# Plot the simulated strain genotypes.
sf.plot.plot_genotype(_world)

#### Demo 2: Hand-Specified Communities (4 Samples, 3 Strains)

In [None]:
n, g, s = 4, 500, 3

# Four hand-specified samples over three strains. Everything passed in `data`
# is fixed; NOTE(review): the remaining model variables are presumably sampled
# by simulate_world from `hyperparameters` — confirm.
sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=range(n),
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    data=dict(
        alpha=1e5 * np.ones(n),
        m=100 * np.ones((n, g)),
        # Strain 0 has gamma = 0 everywhere; strains 1 and 2 have gamma = 1 at the
        # first 100 and 20 positions. Written in terms of g so every row has length g.
        gamma=[[0.] * g,
               [1.] * 100 + [0.] * (g - 100),
               [1.] * 20 + [0.] * (g - 20),
              ],
        delta=np.ones((s, g)),
        epsilon=1e-5 * np.ones(n),
        # One row per sample, one column per strain.
        pi=[[1.0, 0.0, 0.0],
            [0.7, 0.3, 0.0],
            [0.8, 0.1, 0.1],
            [0.34, 0.33, 0.33],
           ],
    ),
    hyperparameters=dict(
        gamma_hyper=0.001,
        delta_hyper_r=0.85,
        delta_hyper_temp=0.001,
        rho_hyper=3.,
        pi_hyper=2.0,
        alpha_hyper_hyper_mean=200.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=1.0,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        mu_hyper_mean=10.0,
        mu_hyper_scale=1.5,
        m_hyper_r_scale=1,
    )
)

sim1 = sim_model.simulate_world(seed=2)

sim1

In [None]:
# Enumerate the entries of sim1.data.variables for quick reference.
list(enumerate(sim1.data.variables))

In [None]:
# Community composition of the four Demo 2 samples.
sf.plot.plot_community(sim1)

In [None]:
# Metagenotype frequency spectrum for every sample, with show_dominant=True.
sf.plot.plot_metagenotype_frequency_spectrum(sim1, show_dominant=True)

In [None]:
# The same spectrum for sample 1 only, in a larger 4x3 panel.
sf.plot.plot_metagenotype_frequency_spectrum(sim1, sample_list=[1], show_dominant=True, axwidth=4, axheight=3)

In [None]:
# Genotypes of the three Demo 2 strains.
sf.plot.plot_genotype(sim1)

In [None]:
# Metagenotype heatmap of the four Demo 2 samples.
sf.plot.plot_metagenotype(sim1)

#### Demo Frequency Spectra

In [None]:
n, g, s = 44, 500, 4

# Frequency-spectrum design. Each sample is one combination of
# (minor-strain fraction, minor-genotype fraction, error rate, depth) and is
# labelled "<strain_frac>,<genotype_frac>,<error_rate:0.0e>,<depth:0.0e>"; the
# plotting cells below look samples up by exactly this label.
# Strain 0 has gamma = 0 everywhere; strains 1-3 have gamma = 1 at the first
# 100/200/300 positions (fractions 0.2/0.4/0.6 of g = 500). The (0.25, -1)
# entry is the even four-strain mixture.
# Labels, depths (m), error rates (epsilon) and compositions (pi) are all built
# from the same grid, so they cannot fall out of alignment.
composition_design = [
    # (minor_strain_frac, minor_genotype_frac, pi)
    (0.0, 0.2, [1.0, 0.0, 0.0, 0.0]),
    (0.1, 0.2, [0.9, 0.1, 0.0, 0.0]),
    (0.3, 0.2, [0.7, 0.3, 0.0, 0.0]),
    (0.5, 0.2, [0.5, 0.5, 0.0, 0.0]),
    (0.1, 0.4, [0.9, 0.0, 0.1, 0.0]),
    (0.3, 0.4, [0.7, 0.0, 0.3, 0.0]),
    (0.5, 0.4, [0.5, 0.0, 0.5, 0.0]),
    (0.1, 0.6, [0.9, 0.0, 0.0, 0.1]),
    (0.3, 0.6, [0.7, 0.0, 0.0, 0.3]),
    (0.5, 0.6, [0.5, 0.0, 0.0, 0.5]),
    (0.25, -1, [.25, .25, .25, .25]),
]
sample_design = [
    (strain_frac, genotype_frac, error_rate, depth, pi)
    for depth in [1e3, 1e1]
    for error_rate in [1e-5, 1e-1]
    for strain_frac, genotype_frac, pi in composition_design
]
assert len(sample_design) == n

sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=[
            f"{strain_frac},{genotype_frac},{error_rate:0.0e},{depth:0.0e}"
            for strain_frac, genotype_frac, error_rate, depth, _ in sample_design
        ],
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    data=dict(
        alpha=1e5 * np.ones(n),
        m=[[depth] * g for _, _, _, depth, _ in sample_design],
        gamma=[[0.] * g,
               [1.] * 100 + [0.] * (g - 100),
               [1.] * 200 + [0.] * (g - 200),
               [1.] * 300 + [0.] * (g - 300),
              ],
        delta=np.ones((s, g)),
        epsilon=[error_rate for _, _, error_rate, _, _ in sample_design],
        pi=[pi for *_, pi in sample_design],
    ),
    hyperparameters=dict(
        mu_hyper_mean=10.0,
        mu_hyper_scale=1.5,
        m_hyper_r_scale=1,
    )
)

sim7 = sim_model.simulate_world(seed=2)

sim7

In [None]:
n, g, s = 44, 500, 4

# Same frequency-spectrum design as sim7, but with alpha = 1e2 instead of 1e5.
# Each sample is labelled
# "<strain_frac>,<genotype_frac>,<error_rate:0.0e>,<depth:0.0e>"; the plotting
# cells below look samples up by exactly this label. The (0.25, -1) entry is
# the even four-strain mixture. Labels, depths (m), error rates (epsilon) and
# compositions (pi) are all built from the same grid, so they stay aligned.
composition_design = [
    # (minor_strain_frac, minor_genotype_frac, pi)
    (0.0, 0.2, [1.0, 0.0, 0.0, 0.0]),
    (0.1, 0.2, [0.9, 0.1, 0.0, 0.0]),
    (0.3, 0.2, [0.7, 0.3, 0.0, 0.0]),
    (0.5, 0.2, [0.5, 0.5, 0.0, 0.0]),
    (0.1, 0.4, [0.9, 0.0, 0.1, 0.0]),
    (0.3, 0.4, [0.7, 0.0, 0.3, 0.0]),
    (0.5, 0.4, [0.5, 0.0, 0.5, 0.0]),
    (0.1, 0.6, [0.9, 0.0, 0.0, 0.1]),
    (0.3, 0.6, [0.7, 0.0, 0.0, 0.3]),
    (0.5, 0.6, [0.5, 0.0, 0.0, 0.5]),
    (0.25, -1, [.25, .25, .25, .25]),
]
sample_design = [
    (strain_frac, genotype_frac, error_rate, depth, pi)
    for depth in [1e3, 1e1]
    for error_rate in [1e-5, 1e-1]
    for strain_frac, genotype_frac, pi in composition_design
]
assert len(sample_design) == n

sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=[
            f"{strain_frac},{genotype_frac},{error_rate:0.0e},{depth:0.0e}"
            for strain_frac, genotype_frac, error_rate, depth, _ in sample_design
        ],
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    data=dict(
        alpha=1e2 * np.ones(n),
        m=[[depth] * g for _, _, _, depth, _ in sample_design],
        gamma=[[0.] * g,
               [1.] * 100 + [0.] * (g - 100),
               [1.] * 200 + [0.] * (g - 200),
               [1.] * 300 + [0.] * (g - 300),
              ],
        delta=np.ones((s, g)),
        epsilon=[error_rate for _, _, error_rate, _, _ in sample_design],
        pi=[pi for *_, pi in sample_design],
    ),
    hyperparameters=dict(
        mu_hyper_mean=10.0,
        mu_hyper_scale=1.5,
        m_hyper_r_scale=1,
    )
)

sim7_low_alpha = sim_model.simulate_world(seed=2)

sim7_low_alpha

In [None]:
_world = sim7
# Keep samples and strains in their designed order (no clustering) and omit
# the row/column annotation colour bars.
sf.plot.plot_community(
    _world,
    row_cluster=False,
    col_cluster=False,
    row_colors=None,
    col_colors=None,
)

In [None]:
_world = sim7
# Metagenotype heatmap in the designed sample order: no clustering,
# no annotation colour bars, no x tick labels.
sf.plot.plot_metagenotype(
    _world,
    row_cluster=False,
    col_cluster=False,
    row_colors=None,
    col_colors=None,
    xticklabels=0,
)

In [None]:
_world = sim7_low_alpha
# Same unclustered heatmap as for sim7, for the low-alpha simulation.
sf.plot.plot_metagenotype(
    _world,
    row_cluster=False,
    col_cluster=False,
    row_colors=None,
    col_colors=None,
    xticklabels=0,
)

In [None]:
_world = sim7

color = 'darkblue'
depth = 1e+03

bins = np.linspace(0.5, 1.0, num=51)

error_rate_list = [1e-5]
minor_strain_frac_list = [0.1, 0.3, 0.5]    # one column per value
minor_genotype_frac_list = [0.2, 0.4, 0.6]  # one row per value

# 3x3 grid of metagenotype frequency spectra for sim7 design samples at this
# depth: rows are minor-genotype fractions, columns are minor-strain fractions.
fig, axs = plt.subplots(3, 3, figsize=(3 * 3, 3 * 2), sharex=True, sharey=True)

# zip() stops at the shorter list, so with one error rate only the first alpha is used.
for error_rate, alpha in zip(error_rate_list, [0.7, 0.25]):
    for minor_genotype_frac, row in zip(minor_genotype_frac_list, axs):
        for minor_strain_frac, ax in zip(minor_strain_frac_list, row):
            sample = f"{minor_strain_frac},{minor_genotype_frac},{error_rate:0.0e},{depth:0.0e}"
            sf.plot.plot_metagenotype_frequency_spectrum(
                _world, sample_list=[sample], show_dominant=False,
                axs=ax, bins=bins, color=color, alpha=alpha,
            )
            ax.set_title("")
            ax.set_ylabel("")

ax.set_ylim(0, 500)  # y is shared, so this applies to every panel
# Label each row with its minor-genotype fraction and each column with its minor-strain fraction.
for minor_genotype_frac, left_ax in zip(minor_genotype_frac_list, axs[:, 0]):
    left_ax.set_ylabel(minor_genotype_frac)
for minor_strain_frac, top_ax in zip(minor_strain_frac_list, axs[0, :]):
    top_ax.set_title(minor_strain_frac)

fig.tight_layout()

In [None]:
_world = sim7

color = 'darkblue'
depth = 1e+03

bins = np.linspace(0.5, 1.0, num=51)

# Overlay two error rates: 1e-1 drawn with transparency 0.7, 1e-5 with 0.25.
error_rate_list = [1e-1, 1e-5]
minor_strain_frac_list = [0.1, 0.3, 0.5]    # one column per value
minor_genotype_frac_list = [0.2, 0.4, 0.6]  # one row per value

# 3x3 grid of metagenotype frequency spectra for sim7 design samples at this
# depth: rows are minor-genotype fractions, columns are minor-strain fractions.
fig, axs = plt.subplots(3, 3, figsize=(3 * 3, 3 * 2), sharex=True, sharey=True)

for error_rate, alpha in zip(error_rate_list, [0.7, 0.25]):
    for minor_genotype_frac, row in zip(minor_genotype_frac_list, axs):
        for minor_strain_frac, ax in zip(minor_strain_frac_list, row):
            sample = f"{minor_strain_frac},{minor_genotype_frac},{error_rate:0.0e},{depth:0.0e}"
            sf.plot.plot_metagenotype_frequency_spectrum(
                _world, sample_list=[sample], show_dominant=False,
                axs=ax, bins=bins, color=color, alpha=alpha,
            )
            ax.set_title("")
            ax.set_ylabel("")

ax.set_ylim(0, 500)  # y is shared, so this applies to every panel
# Label each row with its minor-genotype fraction and each column with its minor-strain fraction.
for minor_genotype_frac, left_ax in zip(minor_genotype_frac_list, axs[:, 0]):
    left_ax.set_ylabel(minor_genotype_frac)
for minor_strain_frac, top_ax in zip(minor_strain_frac_list, axs[0, :]):
    top_ax.set_title(minor_strain_frac)

fig.tight_layout()

In [None]:
color = 'darkblue'
depth = 1e+03

bins = np.linspace(0.5, 1.0, num=51)

error_rate_list = [1e-5]
# This cell loops over worlds, not error rates, so bind the error rate explicitly
# instead of relying on a value left over from an earlier cell's loop.
error_rate = error_rate_list[0]
minor_strain_frac_list = [0.1, 0.3, 0.5]    # one column per value
minor_genotype_frac_list = [0.2, 0.4, 0.6]  # one row per value

fig, axs = plt.subplots(3, 3, figsize=(3 * 3, 3 * 2), sharex=True, sharey=True)

# Overlay sim7 (alpha=1e5, drawn with transparency 0.15) and
# sim7_low_alpha (alpha=1e2, drawn with transparency 0.7).
for _world, alpha in zip([sim7, sim7_low_alpha], [0.15, 0.7]):
    for minor_genotype_frac, row in zip(minor_genotype_frac_list, axs):
        for minor_strain_frac, ax in zip(minor_strain_frac_list, row):
            sample = f"{minor_strain_frac},{minor_genotype_frac},{error_rate:0.0e},{depth:0.0e}"
            sf.plot.plot_metagenotype_frequency_spectrum(
                _world, sample_list=[sample], show_dominant=False,
                axs=ax, bins=bins, color=color, alpha=alpha,
            )
            ax.set_title("")
            ax.set_ylabel("")

ax.set_ylim(0, 500)  # y is shared, so this applies to every panel
# Label each row with its minor-genotype fraction and each column with its minor-strain fraction.
for minor_genotype_frac, left_ax in zip(minor_genotype_frac_list, axs[:, 0]):
    left_ax.set_ylabel(minor_genotype_frac)
for minor_strain_frac, top_ax in zip(minor_strain_frac_list, axs[0, :]):
    top_ax.set_title(minor_strain_frac)

fig.tight_layout()

In [None]:
_world = sim7

color = 'darkblue'
depth = 1e+01

bins = np.linspace(0.5, 1.0, num=51)

error_rate_list = [1e-5]
minor_strain_frac_list = [0.1, 0.3, 0.5]    # one column per value
minor_genotype_frac_list = [0.2, 0.4, 0.6]  # one row per value

# 3x3 grid of metagenotype frequency spectra for the low-depth (1e1) sim7
# samples: rows are minor-genotype fractions, columns are minor-strain fractions.
fig, axs = plt.subplots(3, 3, figsize=(3 * 3, 3 * 2), sharex=True, sharey=True)

# zip() stops at the shorter list, so with one error rate only the first alpha is used.
for error_rate, alpha in zip(error_rate_list, [0.7, 0.25]):
    for minor_genotype_frac, row in zip(minor_genotype_frac_list, axs):
        for minor_strain_frac, ax in zip(minor_strain_frac_list, row):
            sample = f"{minor_strain_frac},{minor_genotype_frac},{error_rate:0.0e},{depth:0.0e}"
            sf.plot.plot_metagenotype_frequency_spectrum(
                _world, sample_list=[sample], show_dominant=False,
                axs=ax, bins=bins, color=color, alpha=alpha,
            )
            ax.set_title("")
            ax.set_ylabel("")

ax.set_ylim(0, 500)  # y is shared, so this applies to every panel
# Label each row with its minor-genotype fraction and each column with its minor-strain fraction.
for minor_genotype_frac, left_ax in zip(minor_genotype_frac_list, axs[:, 0]):
    left_ax.set_ylabel(minor_genotype_frac)
for minor_strain_frac, top_ax in zip(minor_strain_frac_list, axs[0, :]):
    top_ax.set_title(minor_strain_frac)

fig.tight_layout()

In [None]:
_world = sim7

error_rate = 1e-05
color = 'darkblue'
depth = 1e+03
# In the sim7 design the even four-strain mixture is labelled with
# minor-strain fraction 0.25 and minor-genotype fraction -1.
minor_strain_frac = 0.25
minor_genotype_frac = -1

bins = np.linspace(0.5, 1.0, num=51)

# Single panel. Only `ax` from this figure is labelled; the grid-labelling loops
# were removed because they wrote into the previous cell's `axs`.
fig, ax = plt.subplots(1, 1, figsize=(1 * 3, 1 * 2))
sample = f"{minor_strain_frac},{minor_genotype_frac},{error_rate:0.0e},{depth:0.0e}"
sf.plot.plot_metagenotype_frequency_spectrum(_world, sample_list=[sample], show_dominant=False, axs=ax, bins=bins, color=color, alpha=0.7)
ax.set_title("")
ax.set_ylabel("")
ax.set_ylim(0, 500)

fig.tight_layout()

In [None]:
_world = sim7
# Genotypes of the four sim7 strains.
sf.plot.plot_genotype(_world)

In [None]:
_world = sim7
# sim7 metagenotype heatmap with the plot function's default options.
sf.plot.plot_metagenotype(_world)

In [None]:
d = sim7

# Ordinations of the same samples under three distances, top to bottom:
# cosine distance between metagenotypes, `pdist` of the metagenotypes, and
# cosine distance between estimated genotypes.
# Colour = second-most-abundant strain; size = dominant-strain fraction * 100.
fig, axs = plt.subplots(3, figsize=(15, 10))

for ax, dmat_func in zip(axs, [
    lambda w: w.metagenotypes.cosine_pdist(),
    lambda w: w.metagenotypes.pdist(),
    lambda w: w.metagenotypes.to_estimated_genotypes().cosine_pdist(),
]):
    sf.plot.ordination_plot(
        d,
        dmat_func=dmat_func,
        vmin=0,
        colors_func=lambda w: w.communities.to_pandas().apply(lambda x: x.drop(x.idxmax()).idxmax(), axis=1),
        sizes_func=lambda w: w.communities.max('strain') * 100,
        cmap=mpl.cm.tab20,
        edgecolor='k',
        lw=0.2,
        ax=ax,
    )
    ax.set_aspect('equal')

None

### Simple Simulated

In [None]:
n, g, s = 200, 500, 3

# 200 samples over three strains. Genotypes, depth (m), error rate (epsilon),
# alpha and a uniform rho are fixed; pi is not given. NOTE(review): presumably
# simulate_world draws pi using pi_hyper — confirm.
sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=range(n),
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    data=dict(
        alpha=1e5 * np.ones(n),
        m=100 * np.ones((n, g)),
        # Written in terms of g so every genotype row has length g.
        gamma=[[0.] * g,
               [1.] * 100 + [0.] * (g - 100),
               [1.] * 20 + [0.] * (g - 20),
              ],
        delta=np.ones((s, g)),
        epsilon=1e-5 * np.ones(n),
        rho=np.ones(s) / s,
    ),
    hyperparameters=dict(
        gamma_hyper=0.001,
        delta_hyper_r=0.85,
        delta_hyper_temp=0.001,
        rho_hyper=3.,
        pi_hyper=0.5,
        alpha_hyper_hyper_mean=200.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=1.0,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
    )
)

sim2 = sim_model.simulate_world(seed=2)

sim2

In [None]:
d = sim2

# Simulated community composition of the 200 samples.
sf.plot.plot_community(d)

In [None]:
d = sim2

# Genotypes of the three strains.
sf.plot.plot_genotype(d)

In [None]:
d = sim2

# Ordinations of the same samples under three distances, top to bottom:
# cosine distance between metagenotypes, `pdist` of the metagenotypes, and
# cosine distance between estimated genotypes.
# Colour = dominant strain; size = (dominant-strain fraction)**3 * 75.
fig, axs = plt.subplots(3, figsize=(15, 10))

for ax, dmat_func in zip(axs, [
    lambda w: w.metagenotypes.cosine_pdist(),
    lambda w: w.metagenotypes.pdist(),
    lambda w: w.metagenotypes.to_estimated_genotypes().cosine_pdist(),
]):
    sf.plot.ordination_plot(
        d,
        dmat_func=dmat_func,
        vmin=0,
        sizes_func=lambda w: w.communities.max('strain')**(3) * 75,
        colors_func=lambda w: w.communities.to_pandas().idxmax(1),
        cmap=mpl.cm.tab20,
        edgecolor='k',
        lw=0.2,
        ax=ax,
    )
    ax.set_aspect('equal')

None

In [None]:
# Metagenotype heatmap annotated with each sample's dominant-strain fraction.
sf.plot.plot_metagenotype(
    sim2,
    col_colors_func=lambda w: xr.Dataset({
        'max_frac': w.communities.max('strain').rename('max_frac'),
    }),
)

### Simple Simulated (Low Coverage)

In [None]:
n, g, s = 200, 500, 3

# Same setup as sim2 except for depth m.
# NOTE(review): this section is titled "Low Coverage", but m = 1000 is 10x the
# m = 100 used for sim2 — confirm which of the title or the value is intended.
sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=range(n),
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    data=dict(
        alpha=1e5 * np.ones(n),
        m=1000 * np.ones((n, g)),
        # Written in terms of g so every genotype row has length g.
        gamma=[[0.] * g,
               [1.] * 100 + [0.] * (g - 100),
               [1.] * 20 + [0.] * (g - 20),
              ],
        delta=np.ones((s, g)),
        epsilon=1e-5 * np.ones(n),
        rho=np.ones(s) / s,
    ),
    hyperparameters=dict(
        gamma_hyper=0.001,
        delta_hyper_r=0.85,
        delta_hyper_temp=0.001,
        rho_hyper=3.,
        pi_hyper=0.5,
        alpha_hyper_hyper_mean=200.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=1.0,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
    )
)

sim3 = sim_model.simulate_world(seed=2)

sim3

In [None]:
d = sim3

# The same heatmap twice: annotated first with each sample's dominant-strain
# fraction, then with its dominant strain (categorical, tab20 colours).
sf.plot.plot_metagenotype(
    d,
    col_colors_func=lambda w: xr.Dataset({
        'max_frac': w.communities.max('strain').rename('max_frac'),
    }),
)
sf.plot.plot_metagenotype(
    d,
    col_colors_func=lambda w: xr.Dataset({
        'max_strain': w.communities.to_pandas().idxmax(1),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
)

In [None]:
d = sim3

# Ordinations under three distances, top to bottom: cosine distance between
# metagenotypes, `pdist` of the metagenotypes, and cosine distance between
# estimated genotypes. Colour = dominant strain; size = (dominant fraction)**3 * 75.
fig, axs = plt.subplots(3, figsize=(15, 5))

for ax, dmat_func in zip(axs, [
    lambda w: w.metagenotypes.cosine_pdist(),
    lambda w: w.metagenotypes.pdist(),
    lambda w: w.metagenotypes.to_estimated_genotypes().cosine_pdist(),
]):
    sf.plot.ordination_plot(
        d,
        dmat_func=dmat_func,
        vmin=0,
        sizes_func=lambda w: w.communities.max('strain')**(3) * 75,
        colors_func=lambda w: w.communities.to_pandas().idxmax(1),
        cmap=mpl.cm.tab20,
        ax=ax,
    )
    ax.set_aspect('equal')

None

### Simple Simulated (Variable Coverage)

In [None]:
n, g, s = 200, 500, 3

# Variable coverage: m is not fixed; m_hyper_r is given as data and the
# mu_hyper_* hyperparameters are set. NOTE(review): presumably depths are then
# drawn per sample by simulate_world — confirm.
sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=range(n),
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    data=dict(
        m_hyper_r=1000. * np.ones((n, 1)),
        alpha=1e5 * np.ones(n),
        # Written in terms of g so every genotype row has length g.
        gamma=[[0.] * g,
               [1.] * 100 + [0.] * (g - 100),
               [1.] * 20 + [0.] * (g - 20),
              ],
        delta=np.ones((s, g)),
        epsilon=1e-5 * np.ones(n),
        rho=np.ones(s) / s,
    ),
    hyperparameters=dict(
        gamma_hyper=0.001,
        delta_hyper_r=0.85,
        delta_hyper_temp=0.001,
        rho_hyper=3.,
        pi_hyper=0.5,
        alpha_hyper_hyper_mean=200.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=1.0,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        mu_hyper_mean=10.0,
        mu_hyper_scale=1.5,
    )
)

sim4 = sim_model.simulate_world(seed=2)

sim4

In [None]:
d = sim4

# The same heatmap twice: annotated first with each sample's dominant-strain
# fraction, then with its dominant strain (categorical, tab20 colours).
sf.plot.plot_metagenotype(
    d,
    col_colors_func=lambda w: xr.Dataset({
        'max_frac': w.communities.max('strain').rename('max_frac'),
    }),
)
sf.plot.plot_metagenotype(
    d,
    col_colors_func=lambda w: xr.Dataset({
        'max_strain': w.communities.to_pandas().idxmax(1),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
)

In [None]:
d = sim4

# Two ordinations: cosine distance between metagenotypes (top) and `pdist` of
# the metagenotypes (bottom). Colour = dominant strain; size = sqrt(mu) * 10.
fig, axs = plt.subplots(2, figsize=(10, 5))

for ax, dmat_func in zip(axs, [
    lambda w: w.metagenotypes.cosine_pdist(),
    lambda w: w.metagenotypes.pdist(),
]):
    sf.plot.ordination_plot(
        d,
        dmat_func=dmat_func,
        vmin=0,
        sizes_func=lambda w: w.data.mu.pipe(np.sqrt) * 10,
        colors_func=lambda w: w.communities.to_pandas().idxmax(1),
        cmap=mpl.cm.tab20,
        ax=ax,
    )
    ax.set_aspect('equal')

None

### Simple Simulated (Complex Noise)

In [None]:
n, g, s = 200, 500, 3

# Complex noise: compared with sim4, alpha is lowered to 1e2 and epsilon raised
# to 1e-2. Depth is not fixed (mu_hyper_* and m_hyper_r_scale are set instead).
sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords=dict(
        sample=range(n),
        position=g,
        strain=s,
        allele=['alt', 'ref'],
    ),
    data=dict(
        alpha=1e2 * np.ones(n),
        # Written in terms of g so every genotype row has length g.
        gamma=[[0.] * g,
               [1.] * 100 + [0.] * (g - 100),
               [1.] * 20 + [0.] * (g - 20),
              ],
        delta=np.ones((s, g)),
        epsilon=1e-2 * np.ones(n),
        rho=np.ones(s) / s,
    ),
    hyperparameters=dict(
        gamma_hyper=0.001,
        delta_hyper_r=0.85,
        delta_hyper_temp=0.001,
        rho_hyper=3.,
        pi_hyper=0.5,
        alpha_hyper_hyper_mean=200.0,
        alpha_hyper_hyper_scale=1.0,
        alpha_hyper_scale=1.0,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        mu_hyper_mean=10.0,
        mu_hyper_scale=1.5,
        m_hyper_r_scale=1,
    )
)

sim5 = sim_model.simulate_world(seed=2)

sim5

In [None]:
d = sim5

# The same heatmap twice: annotated first with each sample's dominant-strain
# fraction, then with its dominant strain (categorical, tab20 colours).
sf.plot.plot_metagenotype(
    d,
    col_colors_func=lambda w: xr.Dataset({
        'max_frac': w.communities.max('strain').rename('max_frac'),
    }),
)
sf.plot.plot_metagenotype(
    d,
    col_colors_func=lambda w: xr.Dataset({
        'max_strain': w.communities.to_pandas().idxmax(1),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
)

In [None]:
d = sim5


fig, axs = plt.subplots(2, figsize=(10, 5))

# One ordination per panel, each using a different sample-distance function:
# cosine distance on top, the default metagenotype distance below. Point size
# scales with sqrt(mu); colour marks each sample's dominant strain.
# Alternatives tried:
#     colors_func=lambda world: world.communities.max('strain'),
#     sizes_func=lambda world: world.communities.max('strain')**(3) * 75,
#     colors_func=lambda world: world.data.alpha.pipe(np.sqrt),
for ax, _dmat_func in zip(
    axs,
    [
        lambda world: world.metagenotypes.cosine_pdist(),
        lambda world: world.metagenotypes.pdist(),
    ],
):
    sf.plot.ordination_plot(
        d,
        ax=ax,
        cmap=mpl.cm.tab20,
        vmin=0,
        dmat_func=_dmat_func,
        sizes_func=lambda world: world.data.mu.pipe(np.sqrt) * 10,
        colors_func=lambda world: world.communities.to_pandas().idxmax(1),
    )
    ax.set_aspect('equal')

None

### Simulate Real-looking Data

In [None]:
# Larger "real-looking" simulation: 150 samples x 1000 positions from 20
# strains. Only m_hyper_r_mean is fixed in `data`; the remaining parameters are
# left to the model (presumably drawn from the priors set by the hyperparameters).
sim_model = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords={
        'sample': 150,
        'position': 1000,
        'strain': 20,
        'allele': ['alt', 'ref'],
    },
    data={
        'm_hyper_r_mean': 5.,
#         'alpha': 100 * np.ones(100),
#         'epsilon': 0.05 * np.ones(100),
#         'alpha': 10000 * np.ones(100),
#         'epsilon': 0.000001 * np.ones(100),
    },
    hyperparameters={
        'gamma_hyper': 0.001,
        'delta_hyper_r': 0.9,
        'delta_hyper_temp': 0.001,
        'rho_hyper': 4.,
        'pi_hyper': 0.4,
        'alpha_hyper_hyper_mean': 200.0,
        'alpha_hyper_hyper_scale': 1.0,
        'alpha_hyper_scale': 1.0,
        'epsilon_hyper_alpha': 1.5,
        'epsilon_hyper_beta': 1.5 / 0.01,
        'mu_hyper_mean': 30.0,
        'mu_hyper_scale': 1.5,
#         'm_hyper_r_mu': 5,
        'm_hyper_r_scale': 0.5,
    },
)
# print(sim_model.data, sim_model.hyperparameters)

sim6 = sim_model.simulate_world(seed=2)

In [None]:
_world = sim6

# True simulated genotypes, drawn transposed; positions are clustered on the
# metagenotypes and strains on their genotypes.
sf.plot.plot_genotype(
    _world,
    transpose=True,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
_world = sim6

# Simulated missingness, drawn transposed with the same clustering as the
# genotype plot.
sf.plot.plot_missing(
    _world,
    transpose=True,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
_world = sim6

# True community composition: samples ordered by metagenotype similarity and
# strains by genotype similarity.
sf.plot.plot_community(
    _world,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('sample'),
)

In [None]:
_world = sim6

# Simulated metagenotypes without annotations; both axes are clustered on the
# metagenotypes themselves.
sf.plot.plot_metagenotype(
    _world,
    col_colors_func=None,
    row_colors_func=None,
    row_linkage_func=lambda world: world.metagenotypes.linkage('position'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('sample'),
)

In [None]:
d = sim6


# Metagenotype heatmap with each sample coloured by its dominant strain.
g = sf.plot.plot_metagenotype(
    d,
    row_col_annotation_cmap=mpl.cm.tab20,
    col_colors_func=lambda world: xr.Dataset({
        'max_strain': world.communities.to_pandas().idxmax(1),
    }),
#     row_linkage_func=lambda world: world.metagenotypes.linkage(dim='position'),
#     col_linkage_func=lambda world: world.metagenotypes.linkage(dim='strain'),
#     metric='euclidean',
)
# sf.plot.plot_genotype(sim, scalex=0.6, scaley=0.02, cwidth=0., cheight=0.1, dwidth=0.2, dheight=1.0)
# sf.plot._calculate_clustermap_sizes(10, 10, scalex=0.6, scaley=0.02, cwidth=0., cheight=0.1, dwidth=0.2, dheight=1.0)
# sf.plot.plot_genotype(sf.data.Metagenotypes.from_counts_and_totals(sim0.data['y'], sim0.data['m']))

In [None]:
d = sim6

fig, ax = plt.subplots(figsize=(5, 5))

# Ordination on cosine distance between the samples' raw metagenotype count
# vectors. `unstack('sample')` sorts the sample level, so the distance matrix
# is labelled with the unstacked frame's own index. Labelling with `w.sample`
# misassigns labels whenever the sample coordinate is not already sorted.
sf.plot.ordination_plot(
    d,
    dmat_func=(
        lambda w: (
            w.metagenotypes.data.to_dataframe().squeeze().unstack('sample').T
            .pipe(
                lambda counts: pd.DataFrame(
                    sp.spatial.distance.squareform(
                        sp.spatial.distance.pdist(counts, 'cosine')
                    ),
                    index=counts.index, columns=counts.index,
                )
            )
        )
    ),
    vmin=0,
#     sizes_func=lambda w: w.data.mu.pipe(np.sqrt) * 10,
#     colors_func=lambda w: w.communities.max('strain'),
    # Point area grows with the sample's largest strain fraction.
    sizes_func=lambda w: w.communities.max('strain')**2 * 75,
    colors_func=lambda w: w.communities.to_pandas().idxmax(1),
#     colors_func=lambda w: w.data.alpha.pipe(np.sqrt),
    cmap=mpl.cm.tab20,
    ax=ax,
)
None

#### Fit Simulated Data

##### Toy Model

In [None]:
d = sim6

# Baseline: fit the toy `simple_metagenotype2` model (up to 30 strains) to the
# real-looking simulation.
model_fit = sf.model.ParameterizedModel(
    sf.model_zoo.simple_metagenotype2,
    coords={
        'sample': d.sample.values,
        'position': d.position.values,
        'allele': d.allele.values,
        'strain': range(30),
    },
    hyperparameters={
        'gamma_hyper': 0.05,
        'pi_hyper': 0.1,
        'rho_hyper': 0.01,
    },
)

# NOTE(review): this call passes `thresh` directly, while other fits in this
# notebook pass `cluster_kwargs=dict(thresh=...)`. Confirm which form the
# current `three_stage_fitting` signature accepts.
est7, history = sf.workflow.three_stage_fitting(
    model_fit.condition(**d.metagenotypes.to_counts_and_totals()),
    stage2_hyperparameters={'gamma_hyper': 1.1},
    thresh=0.05,
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
_world = est7

# Toy-model genotype estimates, laid out like the true-genotype plot above.
sf.plot.plot_genotype(
    _world,
    transpose=True,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
_world = est7

# Toy-model community estimates, laid out like the true-community plot above.
sf.plot.plot_community(
    _world,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('sample'),
)

In [None]:
# Compare the true simulation (sim6) with the toy-model estimate (est7); the
# helper presumably compares sample-by-sample beta diversity.
sf.plot.plot_beta_diversity_comparison(
    sim6,
    est7,
)

##### Full Model

In [None]:
d = sim6

# Full hybrid model (up to 30 strains) fit to the real-looking simulation.
model_fit = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords={
        'sample': d.sample.values,
        'position': d.position.values,
        'allele': d.allele.values,
        'strain': range(30),
    },
    hyperparameters={
        'gamma_hyper': 0.01,
        'delta_hyper_r': 0.8,
        'delta_hyper_temp': 0.1,
        'rho_hyper': 0.01,
        'pi_hyper': 0.5,
        'alpha_hyper_hyper_mean': 200.0,
        'alpha_hyper_hyper_scale': 1.0,
        'alpha_hyper_scale': 0.5,
        'epsilon_hyper_alpha': 1.5,
        'epsilon_hyper_beta': 1.5 / 0.01,
    },
)

# NOTE(review): `thresh` is passed directly here, but as
# `cluster_kwargs=dict(thresh=...)` in other fits. Confirm the expected form.
est6, history = sf.workflow.three_stage_fitting(
    model_fit.condition(**d.metagenotypes.to_counts_and_totals()),
    stage2_hyperparameters={'gamma_hyper': 1.1},
    thresh=0.05,
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))

d = est6

# Ordination on cosine distance between the samples' raw metagenotype count
# vectors. `unstack('sample')` sorts the sample level, so the distance matrix
# is labelled with the unstacked frame's own index. Labelling with `w.sample`
# misassigns labels whenever the sample coordinate is not already sorted.
sf.plot.ordination_plot(
    d,
    dmat_func=(
        lambda w: (
            w.metagenotypes.data.to_dataframe().squeeze().unstack('sample').T
            .pipe(
                lambda counts: pd.DataFrame(
                    sp.spatial.distance.squareform(
                        sp.spatial.distance.pdist(counts, 'cosine')
                    ),
                    index=counts.index, columns=counts.index,
                )
            )
        )
    ),
    vmin=0,
#     sizes_func=lambda w: w.data.mu.pipe(np.sqrt) * 10,
#     colors_func=lambda w: w.communities.max('strain'),
    # Point area grows with the sample's largest estimated strain fraction.
    sizes_func=lambda w: w.communities.max('strain')**2 * 75,
    colors_func=lambda w: w.communities.to_pandas().idxmax(1),
#     colors_func=lambda w: w.data.alpha.pipe(np.sqrt),
    cmap=mpl.cm.tab20,
    ax=ax,
)
None

In [None]:
_world = est6

# Full-model genotype estimates, transposed, with strains and positions clustered.
sf.plot.plot_genotype(
    _world,
    transpose=True,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
_world = est6

# Full-model missingness estimates, with the same layout as the genotype plot.
sf.plot.plot_missing(
    _world,
    transpose=True,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
_world = est6

# Full-model community estimates; samples ordered by metagenotype similarity.
sf.plot.plot_community(
    _world,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('sample'),
)

In [None]:
_world = est6

# Metagenotypes of the est6 world without annotations, clustered on both axes.
sf.plot.plot_metagenotype(
    _world,
    col_colors_func=None,
    row_colors_func=None,
    row_linkage_func=lambda world: world.metagenotypes.linkage('position'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('sample'),
)

## Fit Simulated Data

In [None]:
# Histogram of dominant-allele fractions for one simulated sample, with a
# dashed line at 1 - (each strain's fraction in that sample).
# NOTE(review): `sim` is not defined in this part of the notebook, and the fits
# in this section use `sim6`. Confirm which simulation is intended.
sample = [7]
bins = np.linspace(start=0.5, stop=1.0, num=21)

d = sim.sel(sample=sample)
plt.hist(
    d.metagenotypes.dominant_allele_fraction().values.T,
    bins=bins,
)
for freq in d.communities.values.squeeze():
    plt.axvline(1 - freq, linestyle='--', lw=1, color='k')
plt.xlim(0.5, 1.0)

### Full Model

In [None]:
d = sim6

# Full hybrid model (up to 30 strains) refit to the real-looking simulation,
# this time with gamma_hyper=1.0 in stage 2.
model_fit = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords={
        'sample': d.sample.values,
        'position': d.position.values,
        'allele': d.allele.values,
        'strain': range(30),
    },
    hyperparameters={
        'gamma_hyper': 0.01,
        'delta_hyper_r': 0.8,
        'delta_hyper_temp': 0.1,
        'rho_hyper': 0.01,
        'pi_hyper': 0.5,
        'alpha_hyper_hyper_mean': 200.0,
        'alpha_hyper_hyper_scale': 1.0,
        'alpha_hyper_scale': 0.5,
        'epsilon_hyper_alpha': 1.5,
        'epsilon_hyper_beta': 1.5 / 0.01,
    },
)

est1, history = sf.workflow.three_stage_fitting(
    model_fit.condition(**d.metagenotypes.to_counts_and_totals()),
    stage2_hyperparameters={'gamma_hyper': 1.0},
    cluster_kwargs={'thresh': 0.05},
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
# est1 was fit to `sim6` in the cell above, so score it against that
# simulation's ground truth (previously compared against `sim`).
print(
    sf.evaluation.weighted_genotype_error(sim6, est1),
    sf.evaluation.community_error(sim6, est1),
)

In [None]:
# True (sim6) and estimated (est6) strains concatenated along the strain axis;
# `which_fit` marks each strain's origin.
sf.plot.plot_community(
    sf.data.World.concat({'sim': sim6, 'est': est6}, dim='strain'),
    row_colors_func=lambda w: xr.Dataset(dict(
        abundance=w.communities.mean("sample"),
        # `entropy` is a method; call it so the annotation holds per-strain
        # values rather than a bound-method object.
        entropy=w.genotypes.entropy(),
        # Same sim=0 / est=1 encoding as the other comparison plots.
        which_fit=w.data['_concat_from'].to_series().map({'sim': 0, 'est': 1}).to_xarray(),
    )),
    norm=None,
)

In [None]:
# True (sim6) and estimated (est6) genotypes concatenated along the strain
# axis; `which_fit` marks each strain's origin (sim=0, est=1).
sf.plot.plot_masked_genotype(
    sf.data.World.concat({'sim': sim6, 'est': est6}, dim='strain'),
    row_colors_func=lambda w: xr.Dataset(dict(
        abundance=w.communities.mean("sample"),
        # Call the method so the annotation holds values, not a bound method.
        entropy=w.genotypes.entropy(),
        which_fit=w.data['_concat_from'].to_series().map({'sim': 0, 'est': 1}).to_xarray(),
    )),
)

In [None]:
# est1 was fit to `sim6` (previously paired with `sim`), so compare their
# genotypes, concatenated along the strain axis.
sf.plot.plot_masked_genotype(
    sf.data.World.concat({'sim': sim6, 'est': est1}, dim='strain'),
    row_colors_func=lambda w: xr.Dataset(dict(
        abundance=w.communities.mean("sample"),
        # Call the method so the annotation holds values, not a bound method.
        entropy=w.genotypes.entropy(),
        which_fit=w.data['_concat_from'].to_series().map({'sim': 0, 'est': 1}).to_xarray(),
    )),
)

In [None]:
# Missingness of the true (sim6, which est1 was fit to) and est1 strains,
# concatenated along the strain axis.
sf.plot.plot_missing(
    sf.data.World.concat({'sim': sim6, 'est': est1}, dim='strain'),
    # These annotations are indexed by strain, and strains are the rows of the
    # missingness plot (cf. the row_linkage_func on strains used for plot_missing
    # above), so they belong in row_colors_func, not col_colors_func.
    row_colors_func=lambda w: xr.Dataset(dict(
        abundance=w.communities.mean("sample"),
        # Call the method so the annotation holds values, not a bound method.
        entropy=w.genotypes.entropy(),
        which_fit=w.data['_concat_from'].to_series().map({'sim': 0, 'est': 1}).to_xarray(),
    )),
)

In [None]:
# True vs. estimated mu, coloured by each sample's observed mean depth (allele
# counts summed, then averaged over positions).
plt.scatter(
    x=sim6.data.mu,
    y=est6.data.mu,
    c=sim6.metagenotypes.sum('allele').mean('position'),
)

In [None]:
# True vs. estimated m_hyper_r.
plt.scatter(
    x=sim6.data.m_hyper_r,
    y=est6.data.m_hyper_r,
)

In [None]:
# True vs. estimated epsilon, coloured by the true mu.
plt.scatter(
    x=sim6.data.epsilon,
    y=est6.data.epsilon,
    alpha=0.7,
    c=sim6.data.mu,
)

## Fit Real Data

In [None]:
# Sanity check on sfacts/data.py: load the real metagenotypes, filter them with
# select_variable_positions(incid_thresh=0.2) and
# select_samples_with_coverage(0.1), wrap them as a World, then run
# validate_constraints() on the naive genotype estimate.
obs = (
    sf.data.Metagenotypes
    .load('data/ucfmt.sp-100022.gtpro-pileup.nc')
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
    .to_world()
)

obs.metagenotypes.to_estimated_genotypes().validate_constraints()

print(obs.sizes)
# Heatmap of all retained positions; samples are coloured by sqrt(mean depth).
sf.plot.plot_metagenotype(
    obs,
#     obs.isel(position=range(1000)),
    col_colors_func=lambda world: (
        world.metagenotypes
        .sum('allele')
        .mean('position')
        .pipe(np.sqrt)
        .rename('mean_depth')
    ),
)

### Toy Model

In [None]:
d = obs

# Toy `simple_metagenotype2` model (up to 50 strains) fit to the real data.
model_fit = sf.model.ParameterizedModel(
    sf.model_zoo.simple_metagenotype2,
    coords={
        'sample': d.sample.values,
        'position': d.position.values,
        'allele': d.allele.values,
        'strain': range(50),
    },
    hyperparameters={
        'gamma_hyper': 0.05,
        'rho_hyper': 0.01,
        'pi_hyper': 0.1,
    },
)

est8, history = sf.workflow.three_stage_fitting(
    model_fit.condition(**d.metagenotypes.to_counts_and_totals()),
    stage2_hyperparameters={'gamma_hyper': 1.1},
    cluster_kwargs={'thresh': 0.05},
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
_world = est8

# Toy-model communities on real data.
# Strain annotations: sqrt mean abundance, peak abundance, genotype entropy.
# Sample annotations: sqrt mean depth, largest strain fraction, metagenotype
# entropy.
sf.plot.plot_community(
    _world,
    row_colors_func=lambda world: xr.Dataset({
        'abundance': world.communities.mean("sample").pipe(np.sqrt),
        'max_abund': world.communities.max("sample"),
        'entropy': world.genotypes.entropy(),
#         'missing': 1 - world.missingness.mean("position"),
#         'which_fit': world.data['_concat_from'].to_series().map({'est': 1}).to_xarray(),
    }),
    col_colors_func=lambda world: xr.Dataset({
#         'mu': world.data.mu.pipe(np.sqrt),
        'mean_depth': world.metagenotypes.sum("allele").mean("position").pipe(np.sqrt),
        'max_frac': world.communities.max('strain').rename('max_frac'),
        'entropy': world.metagenotypes.entropy(),
    }),
#     norm=None,
)

In [None]:
_world = est8

# Toy-model genotypes on real data. Strains are annotated with mean abundance,
# genotype entropy and peak abundance across samples.
sf.plot.plot_genotype(
    _world,
    row_colors_func=lambda world: xr.Dataset({
        'abundance': world.communities.mean("sample"),
        'entropy': world.genotypes.entropy(),
        'max_frac': world.communities.max("sample"),
    }),
)

In [None]:
_world = est8

# Metagenotypes of the est8 world, with samples ordered by their estimated
# communities.
sf.plot.plot_metagenotype(
    _world,
    col_linkage_func=lambda world: world.communities.linkage('sample'),
#     row_colors_func=lambda world: xr.Dataset({
#         'abundance': world.communities.mean("sample"),
#         'entropy': world.genotypes.entropy(),
# #         'missing': 1 - world.missingness.mean("position"),
# #         'which_fit': world.data['_concat_from'].to_series().map({'est': 1}).to_xarray(),
#     }),
    col_colors_func=lambda world: xr.Dataset({
#         'mu': world.data.mu.pipe(np.sqrt),
        'mean_depth': world.metagenotypes.sum("allele").mean("position").pipe(np.sqrt),
        'entropy': world.metagenotypes.entropy(),
        'max_frac': world.communities.max('strain').rename('max_frac'),
    }),
#     norm=None,
)

In [None]:
d = obs

# Original `simple_metagenotype` model (up to 30 strains) fit to the real data
# with a single estimate_parameters call. The three-stage variant is kept
# below for reference.
model_fit = sf.model.ParameterizedModel(
    sf.model_zoo.simple_metagenotype,
    coords={
        'sample': d.sample.values,
        'position': d.position.values,
        'allele': d.allele.values,
        'strain': range(30),
    },
    hyperparameters={
        'gamma_hyper': 0.005,
        'rho_hyper': 0.01,
        'pi_hyper': 0.5,
    },
)

# est9, history = sf.workflow.three_stage_fitting(
#     model_fit.condition(
#         **d.metagenotypes.to_counts_and_totals()
#     ),
#     stage2_hyperparameters=dict(gamma_hyper=1.0),
#     cluster_kwargs=dict(thresh=0.02),
#     lagA=20,
#     lagB=200,
#     opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
#     seed=1,
# )

est9, history = sf.estimation.estimate_parameters(
    model_fit.condition(**d.metagenotypes.to_counts_and_totals()),
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
_world = est9

# Communities estimated by the simple model on real data.
# Strain annotations: sqrt mean abundance, peak abundance, genotype entropy.
# Sample annotations: sqrt mean depth, largest strain fraction.
sf.plot.plot_community(
    _world,
    row_colors_func=lambda w: xr.Dataset(dict(
        abundance=w.communities.mean("sample").pipe(np.sqrt),
        max_abund=w.communities.max("sample"),
        # `entropy` is a method; call it so the annotation holds per-strain
        # values instead of a bound-method object.
        entropy=w.genotypes.entropy(),
#         missing=1 - w.missingness.mean("position"),
#         which_fit=w.data['_concat_from'].to_series().map({'est': 1}).to_xarray(),
    )),
    col_colors_func=lambda w: xr.Dataset(dict(
#         mu=w.data.mu.pipe(np.sqrt),
        mean_depth=w.metagenotypes.sum("allele").mean("position").pipe(np.sqrt),
#         alpha=w.data.alpha.pipe(np.sqrt),
        max_frac=w.communities.max('strain').rename('max_frac'),
    )),
#     norm=None,
)

In [None]:
_world = est9

# Genotypes estimated by the simple model. Strains are annotated with mean
# and peak abundance and genotype entropy.
sf.plot.plot_genotype(
    _world,
    row_colors_func=lambda w: xr.Dataset(dict(
        abundance=w.communities.mean("sample"),
        max_abund=w.communities.max("sample"),
        # Call the method so the annotation holds values, not a bound method.
        entropy=w.genotypes.entropy(),
    )),
)

In [None]:
_world = est9

# Metagenotypes of the est9 world. Samples are ordered by estimated
# communities, positions by estimated genotypes.
sf.plot.plot_metagenotype(
    _world,
    row_linkage_func=lambda world: world.genotypes.linkage('position'),
    col_linkage_func=lambda world: world.communities.linkage('sample'),
    col_colors_func=lambda world: xr.Dataset({
        'mean_depth': world.metagenotypes.sum("allele").mean("position").pipe(np.sqrt),
        'max_frac': world.communities.max('strain').rename('max_frac'),
        'entropy': world.metagenotypes.entropy(),
    }),
)

In [None]:
_world = est9

# Depth heatmap, ordered the same way as the metagenotype plot above.
sf.plot.plot_depth(
    _world,
    row_linkage_func=lambda world: world.genotypes.linkage('position'),
    col_linkage_func=lambda world: world.communities.linkage('sample'),
    col_colors_func=lambda world: xr.Dataset({
        'mean_depth': world.metagenotypes.sum("allele").mean("position").pipe(np.sqrt),
        'max_frac': world.communities.max('strain').rename('max_frac'),
    }),
)

In [None]:
_est = est9

# Re-simulate metagenotypes from the est9 fit under the simple model, using
# its communities, rounded genotypes and per-sample/position total counts.
model_sim = sf.model.ParameterizedModel(
    sf.model_zoo.simple_metagenotype,
    coords={
        'sample': _est.sample.values,
        'position': _est.position.values,
        'allele': _est.allele.values,
        'strain': _est.strain.values,
    },
    hyperparameters={
        'gamma_hyper': 0.01,
        'rho_hyper': 0.01,
        'pi_hyper': 0.5,
    },
    data={
        'pi': _est.communities.values,
        'gamma': _est.genotypes.values.round(),
        'm': _est.metagenotypes.sum("allele").values,
    },
)

sim8 = model_sim.simulate_world(seed=1)

In [None]:
# Frequency spectrum of one real sample (dark blue) overlaid with the same
# sample re-simulated from the est9 fit (dark green).
# sample = ['SS01052']
# sample = ['SS01068']
sample = ['SS01097']

fig, ax = plt.subplots()

# The loop previously also unpacked an unused `show_dominant` flag; it is dropped.
for _world, c in zip([obs, sim8], ['darkblue', 'darkgreen']):
    sf.plot.plot_metagenotype_frequency_spectrum(
        _world,
        sample_list=sample,
        axs=ax,
        color=c,
        alpha=0.5,
    )

### Full Model

In [None]:
d = obs  # .isel(position=range(500))

# Full hybrid model (up to 30 strains) fit to the real data.
model_fit2 = sf.model.ParameterizedModel(
    sf.model_zoo.hybrid_fuzzy_missing_dp_betabinomial_metagenotype,
    coords={
        'sample': d.sample.values,
        'position': d.position.values,
        'allele': d.allele.values,
        'strain': range(30),
    },
    hyperparameters={
        'gamma_hyper': 0.01,
        'delta_hyper_r': 0.8,
        'delta_hyper_temp': 0.1,
        'rho_hyper': 0.01,
        'pi_hyper': 0.5,
        'alpha_hyper_hyper_mean': 200.0,
        'alpha_hyper_hyper_scale': 1.0,
        'alpha_hyper_scale': 0.5,
        'epsilon_hyper_alpha': 1.5,
        'epsilon_hyper_beta': 1.5 / 0.01,
    },
)

est2, history = sf.workflow.three_stage_fitting(
    model_fit2.condition(**d.metagenotypes.to_counts_and_totals()),
    stage2_hyperparameters={'gamma_hyper': 1.1},
    lagA=20,
    lagB=200,
    opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))

_world = est2

# Cosine-distance ordination of the est2 world's metagenotypes, with the
# library's default point sizes and colours.
sf.plot.ordination_plot(
    _world,
    ax=ax,
    cmap=mpl.cm.tab20,
    vmin=0,
    dmat_func=lambda world: world.metagenotypes.cosine_pdist(),
#     sizes_func=lambda world: world.data.mu.pipe(np.sqrt) * 10,
#     colors_func=lambda world: world.communities.max('strain'),
#     sizes_func=lambda world: world.communities.max('strain')**2 * 75,
#     colors_func=lambda world: world.communities.to_pandas().idxmax(1),
#     colors_func=lambda world: world.data.alpha.pipe(np.sqrt),
)
None

In [None]:
# Observed mean depth per sample vs. the fitted mu of the full-model real-data
# fit. Uses est2, the fit produced in this section (previously `est3`).
plt.scatter(est2.metagenotypes.sum('allele').mean('position'), est2.data.mu)

In [None]:
# Distribution of the fitted epsilon values across samples.
plt.hist(x=est2.data.epsilon)

In [None]:
# Full-model communities on real data.
# Strain annotations: mean abundance, genotype entropy, 1 - mean missingness.
# Sample annotations: sqrt(mu), sqrt(alpha), largest strain fraction.
sf.plot.plot_community(
    est2,
    row_colors_func=lambda world: xr.Dataset({
        'abundance': world.communities.mean("sample"),
        'entropy': world.genotypes.entropy(),
        'missing': 1 - world.missingness.mean("position"),
#         'which_fit': world.data['_concat_from'].to_series().map({'est': 1}).to_xarray(),
    }),
    col_colors_func=lambda world: xr.Dataset({
        'mu': world.data.mu.pipe(np.sqrt),
        'alpha': world.data.alpha.pipe(np.sqrt),
        'max_frac': world.communities.max('strain').rename('max_frac'),
    }),
#     norm=None,
#     transpose=True,
)

In [None]:
# Histogram of dominant-allele fractions for one real sample, with a dashed
# line at 1 - (each estimated strain fraction in that sample).
sample = ['DS0097_035']
# sample = ['SS01105']
bins = np.linspace(start=0.5, stop=1.0, num=21)

d = est2.sel(sample=sample)
plt.hist(
    d.metagenotypes.dominant_allele_fraction().values.T,
    bins=bins,
)
for freq in d.communities.values.squeeze():
    plt.axvline(1 - freq, linestyle='--', lw=1, color='k')
plt.xlim(0.5, 1.0)

In [None]:
_world = est2

# Full-model genotype estimates on real data, transposed and clustered.
sf.plot.plot_genotype(
    _world,
    transpose=True,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
_world = est2

# Full-model missingness estimates on real data, laid out like the genotypes.
sf.plot.plot_missing(
    _world,
    transpose=True,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
_world = est2

# Full-model communities, with samples clustered on the estimated communities.
sf.plot.plot_community(
    _world,
    row_colors_func=None,
    row_linkage_func=lambda world: world.genotypes.linkage('strain'),
    col_linkage_func=lambda world: world.communities.linkage('sample'),
)

In [None]:
_world = est2

# Metagenotypes with samples ordered by estimated communities and positions by
# metagenotype similarity; no colour annotations.
sf.plot.plot_metagenotype(
    _world,
    col_colors_func=None,
    row_colors_func=None,
    row_linkage_func=lambda world: world.metagenotypes.linkage('position'),
    col_linkage_func=lambda world: world.communities.linkage('sample'),
)

In [None]:
_est = est2

# Re-simulate metagenotypes from the full-model fit under the simple model,
# using its communities, unrounded genotypes and total counts.
model_sim = sf.model.ParameterizedModel(
    sf.model_zoo.simple_metagenotype,
    coords={
        'sample': _est.sample.values,
        'position': _est.position.values,
        'allele': _est.allele.values,
        'strain': _est.strain.values,
    },
    hyperparameters={
        'gamma_hyper': 0.01,
        'rho_hyper': 0.01,
        'pi_hyper': 0.5,
    },
    data={
        'pi': _est.communities.values,
        'gamma': _est.genotypes.values,  # optionally .round()
        'm': _est.metagenotypes.sum("allele").values,
    },
)

sim9 = model_sim.simulate_world(seed=1)

In [None]:
# Frequency spectrum of one real sample (dark blue) overlaid with its
# re-simulation from the full-model fit (dark green).
sample = ['DS0097_035']
# sample = ['SS01052']
# sample = ['SS01068']
# sample = ['SS01097']

fig, ax = plt.subplots()

for _world, c in zip([obs, sim9], ['darkblue', 'darkgreen']):
    sf.plot.plot_metagenotype_frequency_spectrum(
        _world,
        sample_list=sample,
        axs=ax,
        color=c,
        alpha=0.25,
    )