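# Simulation benchmark: recovering a simulated community with sfacts

Simulate metagenotypes from the `full_metagenotype_dirichlet_rho` model, fit the same model structure (with more strains than were simulated) using the subsample → collapse → iteratively-refit workflow, and score the fitted community against the simulated ground truth.

## Preamble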

In [None]:
# Automatically re-import edited modules (e.g. a locally edited sfacts) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf
import pyro
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import warnings
import pandas as pd
import scipy as sp
import torch

In [None]:
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=torch.jit.TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

## Simulation

In [None]:
# Simulate a ground-truth world from the full_metagenotype_dirichlet_rho model.
# The fitting section below fits this world's metagenotypes, and the final cell
# compares the fit against it with sf.evaluation.community_error.
structure_sim = sf.model_zoo.full_metagenotype_dirichlet_rho.full_metagenotype_dirichlet_rho_model_structure

# Coordinates shared by the simulated and fitted models (the fitting cell copies
# this dict). Integer entries are presumably dimension sizes; `allele` gives
# explicit labels. The strain count is added per model below.
coords = dict(
    sample=1000,
    position=500,
    allele=['alt', 'ref'],
)
nstrain_sim = 200
hyperparameters_sim = dict(
    gamma_hyper=1e-3,
    delta_hyper_temp=1e-3,
    delta_hyper_r=0.9,
    rho_hyper=5.0,
    pi_hyper=0.2,
    mu_hyper_mean=1.0,
    mu_hyper_scale=1.,
    epsilon_hyper_mode=0.01,
    epsilon_hyper_spread=1.5,
    alpha_hyper_hyper_mean=10.0,
    alpha_hyper_hyper_scale=0.5,
    alpha_hyper_scale=0.5,
)
# Values passed to the simulation model as fixed `data`.
condition_on_sim = dict(
    m_hyper_r_mean=5,
    m_hyper_r_scale=1,
)
device = 'cpu'
dtype = torch.float32
# Seed for the simulation so Restart & Run All regenerates the same world.
seed_sim = 0

coords_sim = coords.copy()
coords_sim.update({'strain': nstrain_sim})
model_sim = sf.model.ParameterizedModel(
    structure_sim,
    coords=coords_sim,
    hyperparameters=hyperparameters_sim,
    data=condition_on_sim,
    device=device,
    dtype=dtype,
)

# pyro.set_rng_seed seeds torch, numpy and Python's `random` together.
pyro.set_rng_seed(seed_sim)
world_sim = model_sim.simulate_world()

In [None]:
# Visualize the simulated (ground-truth) community.
sf.plot.plot_community(world_sim)

## Fitting

In [None]:
# Fit the same model structure to the simulated metagenotypes, allowing more
# strains (nstrain_fit) than were simulated (nstrain_sim), via sfacts'
# subsample -> collapse -> iteratively-refit-genotypes workflow.
structure_fit = sf.model_zoo.full_metagenotype_dirichlet_rho.full_metagenotype_dirichlet_rho_model_structure
nstrain_fit = 300
# Number of positions randomly drawn from the simulated metagenotypes for the
# fit; must not exceed coords['position']. The fit model's `position`
# coordinate is set to match so the two cannot silently disagree.
nposition_fit = 500
hyperparameters_fit = dict(
    gamma_hyper=0.5,
    delta_hyper_temp=0.1,
    delta_hyper_r=0.9,
    rho_hyper=0.01,
    pi_hyper=0.5,
    mu_hyper_mean=10.0,
    mu_hyper_scale=10.,
    epsilon_hyper_mode=0.01,
    epsilon_hyper_spread=1.5,
    alpha_hyper_hyper_mean=100.0,
    alpha_hyper_hyper_scale=1.,
    alpha_hyper_scale=0.5,
)
# Hyperparameter overrides passed to the workflow as `stage2_hyperparameters`
# (presumably applied during the refit stage).
stage2_hyperparameters = dict(
    gamma_hyper=1.0,
)
condition_on_fit = dict(
)
# Seed for the position subsample and stochastic optimization, so this cell is
# reproducible even when re-run on its own.
seed_fit = 0

coords_fit = coords.copy()
coords_fit.update({'strain': nstrain_fit, 'position': nposition_fit})
# NOTE: model_fit is not passed to the workflow below (which takes
# structure_fit directly); it is kept for inspection. device/dtype match the
# workflow call and model_sim for consistency.
model_fit = sf.model.ParameterizedModel(
    structure_fit,
    coords=coords_fit,
    hyperparameters=hyperparameters_fit,
    data=condition_on_fit,
    device=device,
    dtype=dtype,
)

# pyro.set_rng_seed seeds torch, numpy and Python's `random` together.
pyro.set_rng_seed(seed_fit)
world_fit = sf.workflow.fit_metagenotype_subsample_collapse_then_iteratively_refit_full_genotypes(
    structure_fit,
    world_sim.metagenotypes.random_sample(nposition_fit, 'position'),
    nstrain=nstrain_fit,
    nposition=nposition_fit,
    hyperparameters=hyperparameters_fit,
    stage2_hyperparameters=stage2_hyperparameters,
    thresh=0.01,
    condition_on=condition_on_fit,
    device=device,
    dtype=dtype,
    estimation_kwargs=dict(
        jit=True,
        maxiter=10000,
        lagA=20,
        lagB=100,
        # Adamax with lr=1.0; pyro's second dict (clip_args) clips the gradient norm at 100.
        opt=pyro.optim.Adamax({"lr": 1e-0}, {"clip_norm": 100}),
    )
)

In [None]:
# Community plot of the fitted world, with columns ordered by latent
# metagenotype linkage. World.concat along `strain` (rename_coords=True,
# presumably to keep strain labels unique) lets further worlds be shown side
# by side: add e.g. 'sim': world_sim.sel(position=world_fit.position) to the
# dict to compare against the ground truth.
sf.plot.plot_community(
    sf.data.World.concat(
        {'fit': world_fit},
        dim='strain',
        rename_coords=True,
    ),
    col_linkage_func=lambda world: sf.data.latent_metagenotypes_linkage(world),
)

In [None]:
# Genotype plot of the fitted world, with columns ordered by the linkage of
# its metagenotypes over `position`. As above, add
# 'sim': world_sim.sel(position=world_fit.position) to the dict to show the
# ground-truth genotypes alongside the fit.
sf.plot.plot_genotype(
    sf.data.World.concat(
        {'fit': world_fit},
        dim='strain',
        rename_coords=True,
    ),
    col_linkage_func=lambda world: world.metagenotypes.linkage('position'),
)

In [None]:
# Compare the fitted community against the simulated (true) community.
sf.evaluation.community_error(world_sim, world_fit)