In [None]:
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import sfacts as sf
from sfacts.data import load_input_data, select_informative_positions
from sfacts.logging_util import info
from sfacts.pandas_util import idxwhere
from sfacts.workflow import fit_to_data

In [None]:
inpath = ['data/ucfmt.sp-100022.gtpro-pileup.nc']
incid_thresh = 0.1
cvrg_thresh = 0.15
npos = 2000
# Seed for the random position subsample below. Without it, re-running the
# notebook selects different positions; the `seed` passed to fit_to_data in a
# later cell does not govern this draw.
subsample_seed = 0


info("Loading input data.")
data = load_input_data(inpath)
info(f"Full data shape: {data.sizes}.")

info("Filtering positions.")
# Keep positions whose minor-allele incidence exceeds `incid_thresh`.
informative_positions = select_informative_positions(
    data, incid_thresh
)
npos_available = len(informative_positions)
info(
    f"Found {npos_available} informative positions with minor "
    f"allele incidence of >{incid_thresh}"
)
assert npos_available > 0, "No informative positions found; lower incid_thresh?"

npos = min(npos, npos_available)
info(f"Randomly sampling {npos} positions.")
# A dedicated seeded Generator (instead of the unseeded global np.random
# state) makes the subsample reproducible across kernel restarts.
position_rng = np.random.default_rng(subsample_seed)
position_ss = position_rng.choice(
    informative_positions,
    size=npos,
    replace=False,
)

info("Filtering libraries.")
# Per library: the fraction of *all* informative positions (not only the
# subsample) with a nonzero count summed over alleles. Libraries where that
# fraction exceeds `cvrg_thresh` are selected via idxwhere.
suff_cvrg_samples = idxwhere(
    (
        (
            data.sel(position=informative_positions).sum(["allele"]) > 0
        ).mean("position")
        > cvrg_thresh
    ).to_series()
)
nlibs = len(suff_cvrg_samples)
info(
    f"Found {nlibs} libraries with >{cvrg_thresh:0.1%} "
    f"of informative positions covered."
)
assert nlibs > 0, "No libraries pass the coverage filter; lower cvrg_thresh?"

In [None]:
info("Constructing input data.")
# Fix the dimension order explicitly. The `n, g_ss` unpacking and the
# `.values` arrays handed to the model are positional, so they must not
# depend on the axis order stored in the input file.
# NOTE(review): assumes the model expects (library, position) arrays, as the
# `n, g_ss` naming suggests — confirm against fit_to_data.
data_fit = data.sel(
    library_id=suff_cvrg_samples, position=position_ss
).transpose("library_id", "position", "allele")
m_ss = data_fit.sum("allele")  # counts summed over alleles
n, g_ss = m_ss.shape  # number of libraries, number of sampled positions
y_obs_ss = data_fit.sel(allele="alt")  # counts of the "alt" allele only

In [None]:
# Estimate computed directly from the observed counts (alt counts vs. totals),
# before any model fitting; presumably a per-position alt-allele frequency.
p_est_ss = sf.genotype.counts_to_p_estimate(y_obs_ss.values, m_ss.values)
sf.plot.plot_genotype(p_est_ss)

In [None]:
seed = 1

# Options for the three stages of fit_to_data, collected into named dicts so
# the call below is a single readable statement. Values match the original run.
preclust_kwargs = dict(
    thresh=0.05,
    additional_strains_factor=0.0,
    progress=False,
)
fit_kwargs = dict(
    gamma_hyper=0.01,
    pi_hyper=1.0,
    rho_hyper=0.5,
    mu_hyper_mean=5,
    mu_hyper_scale=5.0,
    m_hyper_r=10.0,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    alpha_hyper_hyper_mean=100.0,
    alpha_hyper_hyper_scale=10.0,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
    device='cpu',
    lag=100,
    lr=2.0,
    progress=True,
)
postclust_kwargs = dict(thresh=0.1)

info("Optimizing model parameters.")
mrg_ss, fit_ss, history = fit_to_data(
    y_obs_ss.values,
    m_ss.values,
    preclust_kwargs=preclust_kwargs,
    fit_kwargs=fit_kwargs,
    postclust_kwargs=postclust_kwargs,
    seed=seed,
)

In [None]:
# Plot the optimization history returned by fit_to_data in the fitting cell,
# e.g. to judge whether the loss has levelled off.
sf.plot.plot_loss_history(history)

In [None]:
# The 'gamma' entry of the fitted result, drawn with the same genotype plot
# used for the count-based estimate (presumably fitted strain genotypes).
gamma_ss = mrg_ss['gamma']
sf.plot.plot_genotype(gamma_ss)

In [None]:
# The 'delta' entry of the fitted result, shown with the missingness plot.
delta_ss = mrg_ss['delta']
sf.plot.plot_missing(delta_ss)

In [None]:
# The 'pi' entry of the fitted result, shown with the community plot.
# NOTE(review): yticklabels=1 presumably labels every row — confirm in sf.plot.
pi_ss = mrg_ss['pi']
sf.plot.plot_community(pi_ss, yticklabels=1)

In [None]:
# Histogram of the fitted 'alpha' values on a log10 scale.
# Explicit fig/ax interface with labelled axes; the trailing `;` suppresses
# the return-value repr that would otherwise be printed as the cell output.
fig, ax = plt.subplots()
ax.hist(np.log10(mrg_ss['alpha']))
ax.set(xlabel="log10(alpha)", ylabel="count", title="Fitted alpha");