## Imports

In [None]:
# IPython autoreload: mode 2 reloads all modules before each cell runs, so
# edits to libraries such as sfacts are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from sfacts.data import load_input_data, select_informative_positions
import numpy as np
from sfacts.logging_util import info
from sfacts.pandas_util import idxwhere
from sfacts.workflow import fit_to_data
import sfacts as sf
import matplotlib as mpl
import matplotlib.pyplot as plt

## UCFMT

### 100022 (F. prausnitzii)

In [None]:
# Fit the sfacts model to the UCFMT GT-Pro pileup for species 100022.
# Options for each workflow stage are gathered into named dicts first so the
# call itself stays short and each stage's settings are easy to tweak.
preclust_opts = {
    'thresh': 0.1,
    'additional_strains_factor': 0.,
    'progress': True,
}
fit_opts = {
    'gamma_hyper': 0.01,
    'pi_hyper': 1.0,
    'rho_hyper': 0.5,
    'mu_hyper_mean': 5,
    'mu_hyper_scale': 5.,
    'm_hyper_r': 10.,
    'delta_hyper_temp': 0.1,
    'delta_hyper_p': 0.9,
    'alpha_hyper_hyper_mean': 100.,
    'alpha_hyper_hyper_scale': 10.,
    'alpha_hyper_scale': 0.5,
    'epsilon_hyper_alpha': 1.5,
    'epsilon_hyper_beta': 1.5 / 0.01,
    'device': 'cuda',
    'lag': 100,
    'lr': 1e-0,
    'progress': True,
}
postclust_opts = {'thresh': 0.1}

mrg_ss, data_fit, history = sf.workflow.fit_from_files(
    ['data/ucfmt.sp-100022.gtpro-pileup.nc'],
    seed=2,
    npos=2369,
    incid_thresh=0.1,
    cvrg_thresh=0.05,
    preclust_kwargs=preclust_opts,
    fit_kwargs=fit_opts,
    postclust_kwargs=postclust_opts,
)

In [None]:
# Plot the `history` returned by fit_from_files (presumably the loss trace).
sf.plot.plot_loss_history(history)

In [None]:
# Plot a genotype estimate taken straight from the observed data: the
# alt-allele counts and the totals summed over the allele dimension are
# handed to counts_to_p_estimate, with no model fitting involved.
alt_counts = data_fit.sel(allele='alt').values
total_counts = data_fit.sum('allele').values
sf.plot.plot_genotype(sf.genotype.counts_to_p_estimate(alt_counts, total_counts))

In [None]:
# Plot the fitted `gamma` array from the merged result with the genotype heatmap helper.
sf.plot.plot_genotype(mrg_ss['gamma'])

In [None]:
# Plot the fitted `delta` array from the merged result with the missingness plotting helper.
sf.plot.plot_missing(mrg_ss['delta'])

In [None]:
# Heatmap of the fitted `pi`, with rows coloured by log10(alpha) and columns
# coloured by the mean masked genotype entropy of each strain.
log_alpha = np.log10(mrg_ss['alpha'])
strain_entropy = sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])
sf.plot.plot_community(
    mrg_ss['pi'],
    yticklabels=1,
    # Colormaps map floats in [0, 1]; log10(alpha) can sit well above 1 (the
    # alpha hyperprior mean above is 100.), which would saturate every row to
    # the top colour. Rescale to the observed range so the colours are informative.
    row_colors=mpl.cm.viridis(mpl.colors.Normalize()(log_alpha)),
    col_colors=mpl.cm.viridis(strain_entropy),
    # Fixed [0, 1] colour range, as in the 102506 community plot.
    norm=mpl.colors.PowerNorm(1/3, vmin=0, vmax=1),
)

In [None]:
# Histogram of log10(alpha) from the fit above (presumably one value per sample).
# `plt` is already imported in the imports cell. The trailing `None` stops the
# (counts, bins, patches) tuple from being echoed, as in the later histogram cells.
plt.hist(np.log10(mrg_ss['alpha']))
None

### 102506 (Escherichia)

In [None]:
# Fit the sfacts model to the UCFMT GT-Pro pileup for species 102506.
# Compared with the 100022 fit, this run changes npos, pi_hyper and lr.
preclust_opts = {
    'thresh': 0.1,
    'additional_strains_factor': 0.,
    'progress': True,
}
fit_opts = {
    'gamma_hyper': 0.01,
    'pi_hyper': 0.5,
    'rho_hyper': 0.5,
    'mu_hyper_mean': 5,
    'mu_hyper_scale': 5.,
    'm_hyper_r': 10.,
    'delta_hyper_temp': 0.1,
    'delta_hyper_p': 0.9,
    'alpha_hyper_hyper_mean': 100.,
    'alpha_hyper_hyper_scale': 10.,
    'alpha_hyper_scale': 0.5,
    'epsilon_hyper_alpha': 1.5,
    'epsilon_hyper_beta': 1.5 / 0.01,
    'device': 'cuda',
    'lag': 100,
    'lr': 2e-0,
    'progress': True,
}
postclust_opts = {'thresh': 0.1}

mrg_ss, data_fit, history = sf.workflow.fit_from_files(
    ['data/ucfmt.sp-102506.gtpro-pileup.nc'],
    seed=2,
    npos=10000,
    incid_thresh=0.1,
    cvrg_thresh=0.05,
    preclust_kwargs=preclust_opts,
    fit_kwargs=fit_opts,
    postclust_kwargs=postclust_opts,
)

In [None]:
# Plot the `history` returned by the 102506 fit (presumably the loss trace).
sf.plot.plot_loss_history(history)

In [None]:
# Heatmap of the fitted `pi`, with rows coloured by log10(alpha) and columns
# coloured by the mean masked genotype entropy of each strain.
log_alpha = np.log10(mrg_ss['alpha'])
strain_entropy = sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])
sf.plot.plot_community(
    mrg_ss['pi'],
    yticklabels=1,
    # Colormaps map floats in [0, 1]; log10(alpha) can sit well above 1 (the
    # alpha hyperprior mean for this fit is 100.), which would saturate every
    # row to the top colour. Rescale to the observed range instead.
    row_colors=mpl.cm.viridis(mpl.colors.Normalize()(log_alpha)),
    col_colors=mpl.cm.viridis(strain_entropy),
    norm=mpl.colors.PowerNorm(1/3, vmin=0, vmax=1),
)

In [None]:
# Histogram of log10(alpha) from the 102506 fit (presumably one value per sample).
# `plt` is already imported in the imports cell. The trailing `None` stops the
# (counts, bins, patches) tuple from being echoed, as in the later histogram cells.
plt.hist(np.log10(mrg_ss['alpha']))
None

## All MGEN

### 100022 (F. prausnitzii)

In [None]:
# Fit species 100022 from the core GT-Pro pileup with preclustering disabled.
# NOTE(review): with preclust=False, `s` in fit_kwargs presumably sets the
# number of strains to fit directly -- confirm against fit_from_files.
# Alpha hyperparameters previously tried here (formerly commented out):
# alpha_hyper_hyper_mean=100., alpha_hyper_hyper_scale=10., alpha_hyper_scale=0.5.
mrg_ss, data_fit, history = sf.workflow.fit_from_files(
    ['data/core.sp-100022.gtpro-pileup.nc'],
    incid_thresh=0.1,
    cvrg_thresh=0.5,
    npos=500,
    preclust=False,
    fit_kwargs=dict(
        s=1400,
        gamma_hyper=0.01,
        pi_hyper=0.001,
        rho_hyper=0.1,
        mu_hyper_mean=5,
        mu_hyper_scale=5.,
        m_hyper_r=10.,
        delta_hyper_temp=0.1,
        delta_hyper_p=0.9,
        alpha_hyper_hyper_mean=1000.,
        alpha_hyper_hyper_scale=0.001,
        alpha_hyper_scale=0.001,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        device='cuda',
        lag=100,
        lr=1e-1,
        progress=True,
    ),
    postclust_kwargs=dict(
        thresh=0.1,
        progress=True,
    ),
    seed=1,
)

In [None]:
# Plot the `history` returned by the core 100022 fit (presumably the loss trace).
sf.plot.plot_loss_history(history)

In [None]:
# 100-bin histogram of log10(alpha); the bare `None` keeps the return tuple
# from being echoed. Adding plt.yscale('log') may help if the tails are sparse.
log10_alpha = np.log10(mrg_ss['alpha'])
plt.hist(log10_alpha, bins=100)
None

In [None]:
# 100-bin histogram of log10(epsilon); the bare `None` keeps the return tuple
# from being echoed.
log10_epsilon = np.log10(mrg_ss['epsilon'])
plt.hist(log10_epsilon, bins=100)
None

In [None]:
# Community heatmap restricted to the first `nsamples` rows of `pi` and to the
# `ntop` columns with the largest sum over axis 0. `nsamples` and `top_strains`
# are reused by the following cells.
nsamples = 200
ntop = 50
top_strains = mrg_ss['pi'].sum(0).argsort()[-ntop:]

log_alpha = np.log10(mrg_ss['alpha'][:nsamples])
strain_entropy = sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])
sf.plot.plot_community(
    mrg_ss['pi'][:nsamples, top_strains],
    yticklabels=1,
    # Colormaps map floats in [0, 1]; log10(alpha) can sit well above 1 (the
    # alpha hyperprior mean for this fit is 1000.), which would saturate every
    # row to the top colour. Rescale to the observed range instead.
    row_colors=mpl.cm.viridis(mpl.colors.Normalize()(log_alpha)),
    col_colors=mpl.cm.viridis(strain_entropy[top_strains]),
    # Fixed [0, 1] colour range, as in the 102506 community plot.
    norm=mpl.colors.PowerNorm(1/3, vmin=0, vmax=1),
)

In [None]:
# Genotype heatmap for the `top_strains` chosen above, with columns coloured
# by each strain's mean masked genotype entropy. The plot's return value is
# kept in `grid`.
entropy_colors = mpl.cm.viridis(
    sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]
)
grid = sf.plot.plot_genotype(mrg_ss['gamma'][top_strains], col_colors=entropy_colors)

In [None]:
# Missingness (`delta`) heatmap for the `top_strains` chosen above, with
# columns coloured by each strain's mean masked genotype entropy. The plot's
# return value is kept in `grid`.
entropy_colors = mpl.cm.viridis(
    sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]
)
grid = sf.plot.plot_missing(mrg_ss['delta'][top_strains], col_colors=entropy_colors)

In [None]:
# Genotype heatmap for every fitted strain, with columns coloured by mean
# masked genotype entropy.
entropy_colors = mpl.cm.viridis(
    sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])
)
sf.plot.plot_genotype(mrg_ss['gamma'], col_colors=entropy_colors)

In [None]:
# Plot the full fitted `delta` array (all strains) with the missingness plotting helper.
sf.plot.plot_missing(mrg_ss['delta'])