## Imports

In [None]:
# Enable IPython's autoreload extension (mode 2), presumably so that edits to
# imported project modules such as sfacts take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from sfacts.data import load_input_data, select_informative_positions
import numpy as np
from sfacts.logging_util import info
from sfacts.pandas_util import idxwhere
from sfacts.workflow import fit_to_data
import sfacts as sf
import matplotlib as mpl
import matplotlib.pyplot as plt
import warnings
import torch
import pandas as pd

In [None]:
# Ignore torch.jit.TracerWarning instances whose text begins with the message
# below. `message` is a regular expression matched (case-insensitively)
# against the start of the warning text; the adjacent literals are joined
# into one string at compile time.
# TODO: also narrow the filter with `module=` (and possibly `lineno=`) once
# the correct module regex for the emitting code is known.
warnings.filterwarnings(
    action="ignore",
    category=torch.jit.TracerWarning,
    message=(
        "torch.tensor results are registered as constants in the trace. "
        "You can safely ignore this warning if you use this function to create "
        "tensors out of constant variables that would be the same every time you "
        "call this function. In any other case, this might cause the trace to be "
        "incorrect."
    ),
)

## UCFMT

### 102506 (Escherichia)

In [None]:
# Load the GT-Pro pileup for species 102506 and run
# sf.workflow.filter_subsample_and_fit on it. The result is unpacked into
# `mrg_ss` (fitted estimates indexed by 'pi', 'gamma', 'delta', 'alpha', ...),
# `data_fit` (the data that was actually fit) and `history` (the loss trace);
# all three are used by the plotting cells that follow.
info("Loading input data.")
data = load_input_data(['data/ucfmt.sp-102506.gtpro-pileup.nc'])

mrg_ss, data_fit, history = sf.workflow.filter_subsample_and_fit(
    data,
    incid_thresh=0.2,
    cvrg_thresh=0.05,
    npos=5000,  # presumably the number of positions subsampled for fitting
    preclust_kwargs=dict(
        thresh=0.1,
        additional_strains_factor=0.,
        progress=True,
    ),
    fit_kwargs=dict(
        gamma_hyper=0.01,
        pi_hyper=0.5,
        rho_hyper=0.5,
        mu_hyper_mean=5,
        mu_hyper_scale=5.,
        m_hyper_r=10.,
        delta_hyper_temp=0.1,
        delta_hyper_p=0.9,
        alpha_hyper_hyper_mean=100.,
        alpha_hyper_hyper_scale=10.,
        alpha_hyper_scale=0.5,
        epsilon_hyper_alpha=1.5,
        # Written as alpha / 0.01, presumably to center the epsilon prior
        # near 0.01 -- confirm against the sfacts model definition.
        epsilon_hyper_beta=1.5 / 0.01,
        # Use the GPU when one is available, otherwise fall back to the CPU
        # instead of failing on machines without CUDA.
        device='cuda' if torch.cuda.is_available() else 'cpu',
        lag=100,
        lr=1.0,
        progress=True
    ),
    postclust_kwargs=dict(
        thresh=0.1,
    ),
    seed=1,
)

In [None]:
# Plot the loss trace (`history`) returned by filter_subsample_and_fit.
sf.plot.plot_loss_history(history)

In [None]:
# Heatmap of allele-frequency estimates computed directly from the counts in
# `data_fit` (alt-allele counts over total counts), with col_colors
# encoding log10(alpha).
# Colormap objects map floats in [0, 1] onto the color ramp and clip anything
# outside that range to the end colors, so log10(alpha) is first rescaled to
# [0, 1] over its observed range; otherwise all colors would be identical.
sf.plot.plot_genotype(
    sf.genotype.counts_to_p_estimate(
        data_fit.sel(allele='alt').values, data_fit.sum('allele').values,
    ),
    col_colors=mpl.cm.viridis(
        mpl.colors.Normalize()(np.log10(np.asarray(mrg_ss['alpha'])))
    ),
)

In [None]:
# Heatmap of the fitted genotypes, mrg_ss['gamma'].
sf.plot.plot_genotype(mrg_ss['gamma'])

In [None]:
# Visualize mrg_ss['delta'] with plot_missing; delta is also passed alongside
# gamma to mean_masked_genotype_entropy (presumably as the mask) when coloring
# the community plots.
sf.plot.plot_missing(mrg_ss['delta'])

In [None]:
# Heatmap of the fitted community composition mrg_ss['pi'], drawn with a
# PowerNorm(1/3) (cube-root) color scale. row_colors encode log10(alpha);
# col_colors encode mean_masked_genotype_entropy(gamma, delta).
# Colormap objects map floats in [0, 1] and clip everything outside that
# range, so log10(alpha) is rescaled to [0, 1] over its observed range first.
# The entropy values are passed through unchanged -- presumably they already
# lie in [0, 1]; TODO confirm.
sf.plot.plot_community(
    mrg_ss['pi'],
    yticklabels=1,
    row_colors=mpl.cm.viridis(
        mpl.colors.Normalize()(np.log10(np.asarray(mrg_ss['alpha'])))
    ),
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])),
    norm=mpl.colors.PowerNorm(1/3),
)

In [None]:
# Histogram of log10(alpha) from the fit. matplotlib.pyplot is already
# imported as `plt` at the top of the notebook, so it is not re-imported here.
# plt.show() renders the figure without echoing the (counts, bins, patches)
# tuple returned by plt.hist.
plt.hist(np.log10(mrg_ss['alpha']))
plt.xlabel('log10(alpha)')
plt.ylabel('count')
plt.show()

#### Associate with metadata

In [None]:
# Same community heatmap as above, but with the rows of mrg_ss['pi'] labeled
# by data_fit.library_id so they can be matched against sample metadata.
# Colormap objects map floats in [0, 1] and clip everything outside that
# range, so log10(alpha) is rescaled to [0, 1] over its observed range first.
# The entropy values are passed through unchanged -- presumably they already
# lie in [0, 1]; TODO confirm.
sf.plot.plot_community(
    pd.DataFrame(mrg_ss['pi'], index=data_fit.library_id),
    yticklabels=1,
    row_colors=mpl.cm.viridis(
        mpl.colors.Normalize()(np.log10(np.asarray(mrg_ss['alpha'])))
    ),
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])),
    norm=mpl.colors.PowerNorm(1/3),
)

### 100022 (F. prausnitzii)

In [None]:
# Load the GT-Pro pileup for species 100022 and run
# sf.workflow.filter_subsample_and_fit on it. The result is unpacked into
# `mrg_ss` (fitted estimates indexed by 'pi', 'gamma', 'delta', 'alpha', ...),
# `data_fit` (the data that was actually fit) and `history` (the loss trace);
# all three are used by the plotting cells that follow.
info("Loading input data.")
data = load_input_data(['data/ucfmt.sp-100022.gtpro-pileup.nc'])

mrg_ss, data_fit, history = sf.workflow.filter_subsample_and_fit(
    data,
    incid_thresh=0.1,
    cvrg_thresh=0.05,
    npos=5000,  # presumably the number of positions subsampled for fitting
    preclust_kwargs=dict(
        thresh=0.1,
        additional_strains_factor=0.,
        progress=True,
    ),
    fit_kwargs=dict(
        gamma_hyper=0.01,
        pi_hyper=0.5,
        rho_hyper=0.5,
        mu_hyper_mean=5,
        mu_hyper_scale=5.,
        m_hyper_r=10.,
        delta_hyper_temp=0.1,
        delta_hyper_p=0.9,
        alpha_hyper_hyper_mean=100.,
        alpha_hyper_hyper_scale=10.,
        alpha_hyper_scale=0.5,
        epsilon_hyper_alpha=1.5,
        # Written as alpha / 0.01, presumably to center the epsilon prior
        # near 0.01 -- confirm against the sfacts model definition.
        epsilon_hyper_beta=1.5 / 0.01,
        # Use the GPU when one is available, otherwise fall back to the CPU
        # instead of failing on machines without CUDA.
        device='cuda' if torch.cuda.is_available() else 'cpu',
        lag=100,
        lr=1.0,
        progress=True
    ),
    postclust_kwargs=dict(
        thresh=0.1,
    ),
    seed=1,
)

In [None]:
# Plot the loss trace (`history`) returned by filter_subsample_and_fit.
sf.plot.plot_loss_history(history)

In [None]:
# Heatmap of allele-frequency estimates computed directly from the counts in
# `data_fit` (alt-allele counts over total counts), with col_colors
# encoding log10(alpha).
# Colormap objects map floats in [0, 1] onto the color ramp and clip anything
# outside that range to the end colors, so log10(alpha) is first rescaled to
# [0, 1] over its observed range; otherwise all colors would be identical.
sf.plot.plot_genotype(
    sf.genotype.counts_to_p_estimate(
        data_fit.sel(allele='alt').values, data_fit.sum('allele').values,
    ),
    col_colors=mpl.cm.viridis(
        mpl.colors.Normalize()(np.log10(np.asarray(mrg_ss['alpha'])))
    ),
)

In [None]:
# Heatmap of the fitted genotypes, mrg_ss['gamma'].
sf.plot.plot_genotype(mrg_ss['gamma'])

In [None]:
# Visualize mrg_ss['delta'] with plot_missing; delta is also passed alongside
# gamma to mean_masked_genotype_entropy (presumably as the mask) when coloring
# the community plots.
sf.plot.plot_missing(mrg_ss['delta'])

In [None]:
# Heatmap of the fitted community composition mrg_ss['pi'], drawn with a
# PowerNorm(1/3) (cube-root) color scale. row_colors encode log10(alpha);
# col_colors encode mean_masked_genotype_entropy(gamma, delta).
# Colormap objects map floats in [0, 1] and clip everything outside that
# range, so log10(alpha) is rescaled to [0, 1] over its observed range first.
# The entropy values are passed through unchanged -- presumably they already
# lie in [0, 1]; TODO confirm.
sf.plot.plot_community(
    mrg_ss['pi'],
    yticklabels=1,
    row_colors=mpl.cm.viridis(
        mpl.colors.Normalize()(np.log10(np.asarray(mrg_ss['alpha'])))
    ),
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])),
    norm=mpl.colors.PowerNorm(1/3),
)

In [None]:
# Histogram of log10(alpha) from the fit. matplotlib.pyplot is already
# imported as `plt` at the top of the notebook, so it is not re-imported here.
# plt.show() renders the figure without echoing the (counts, bins, patches)
# tuple returned by plt.hist.
plt.hist(np.log10(mrg_ss['alpha']))
plt.xlabel('log10(alpha)')
plt.ylabel('count')
plt.show()

#### Associate with metadata

In [None]:
# Same community heatmap as above, but with the rows of mrg_ss['pi'] labeled
# by data_fit.library_id so they can be matched against sample metadata.
# Colormap objects map floats in [0, 1] and clip everything outside that
# range, so log10(alpha) is rescaled to [0, 1] over its observed range first.
# The entropy values are passed through unchanged -- presumably they already
# lie in [0, 1]; TODO confirm.
sf.plot.plot_community(
    pd.DataFrame(mrg_ss['pi'], index=data_fit.library_id),
    yticklabels=1,
    row_colors=mpl.cm.viridis(
        mpl.colors.Normalize()(np.log10(np.asarray(mrg_ss['alpha'])))
    ),
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])),
    norm=mpl.colors.PowerNorm(1/3),
)