## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

In [None]:
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=torch.jit.TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

## Library

In [None]:
# Shell escape: print the directory tree of the local `sfacts` package,
# leaving out entries that match `__pycache__`.
!tree sfacts -I __pycache__

### `__init__.py`

### `pyro_util.py`

### `data.py`

### `plot.py`

### `model.py`

### `model_zoo.py`

### `estimation.py`

### `evaluation.py`

### `workflow.py`

## Prototype

In [None]:
# Sanity check on sfacts/data.py: load -> filter -> subsample -> validate -> plot.
np.random.seed(1)  # seed NumPy's global RNG before the random subsample below

# Load the pileup, apply the position filter (incid_thresh=0.2) and the sample
# coverage filter (0.1), then convert the result with `.to_world()`.
obs_all = (
    sf.data.Metagenotypes
    .load("data/ucfmt.sp-100022.gtpro-pileup.nc")
    .select_variable_positions(incid_thresh=0.2)
    .select_samples_with_coverage(0.1)
    .to_world()
)

# Random subsample of 1500 along the "position" dimension.
obs = obs_all.random_sample(1500, "position")

# Test .validate_constraints()
obs.metagenotypes.to_estimated_genotypes().validate_constraints()

sf.plot.plot_metagenotype(obs)

In [None]:
# Strain count for this prototype: passed to the NMF approximation and used to
# size the `strain` coordinate (`range(s)`) of the fitted model.
s = 50

In [None]:
# NMF approximation of `obs` with `s` as the second argument; its genotypes and
# communities are used further down to initialise the model fit.
approx = sf.estimation.nmf_approximation(
    obs,
    s,
    random_state=1,
    alpha=0.0,
    solver="cd",
    init="random",
    tol=1e-4,
)

In [None]:
# Visual check of the NMF approximation: first its genotypes, then its communities.
sf.plot.plot_genotype(approx)
sf.plot.plot_community(approx)

In [None]:
d = obs

# Coordinates come straight from the observed world, plus `s` integer strain labels.
_fit_coords = dict(
    sample=d.sample.values,
    position=d.position.values,
    allele=d.allele.values,
    strain=range(s),
)

# Hyperparameters handed to the model structure.
_fit_hyperparameters = dict(
    gamma_hyper=0.1,
    delta_hyper_r=0.8,
    delta_hyper_temp=0.1,
    rho_hyper=0.01,
    pi_hyper=0.5,
    alpha_hyper_hyper_mean=1000.0,
    alpha_hyper_hyper_scale=0.5,
    alpha_hyper_scale=1.0,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.001,
)

model_fit = sf.model.ParameterizedModel(
    sf.model_zoo.full_metagenotype_model_structure,
    coords=_fit_coords,
    hyperparameters=_fit_hyperparameters,
    # Disabled alternative: fix alpha as data.
    #     data=dict(alpha=np.ones(d.sizes['sample']) * 1e5),
)

# Condition on the observed counts/totals first, then build the initial values
# (same evaluation order as passing both inline to `simple_fit`).
_conditioned_model = model_fit.condition(**d.metagenotypes.to_counts_and_totals())

# Start gamma / pi from fuzzed versions of the NMF approximation.
_init_params = dict(
    gamma=approx.genotypes.fuzzed(eps=1e-1).values,
    pi=approx.communities.fuzzed(eps=1e-2).values,
)

est, history = sf.workflow.simple_fit(
    _conditioned_model,
    # Disabled options:
    #     stage2_hyperparameters=dict(gamma_hyper=1.0),
    #     thresh=0.02,
    initialize_params=_init_params,
    lagA=10,
    lagB=100,
    opt=pyro.optim.Adamax({"lr": 1.0}, {"clip_norm": 100}),
    seed=1,
)

sf.plot.plot_loss_history(history)

In [None]:
# Community plot of the fit. Column annotations are built from per-sample
# summaries; row annotations from per-strain summaries (after summing over
# "sample").
sf.plot.plot_community(
    est,
    col_colors_func=lambda world: xr.Dataset(
        {
            "mgen_entropy": world.metagenotypes.entropy(),
            "expect_entropy": (
                world.data["p"].pipe(sf.math.binary_entropy).mean("position")
            ),
            "mean_cvrg": world.metagenotypes.sum("allele").mean("position"),
            "m_hyper_r": world.data["m_hyper_r"],
            # Note: stored under "alpha" but holds log(alpha).
            "alpha": world.data["alpha"].pipe(np.log),
            # alpha below 10 together with mean_cvrg above 20.
            "flag": (
                (world.data["alpha"] < 10)
                & (world.metagenotypes.sum("allele").mean("position") > 20)
            ),
        }
    ),
    row_colors_func=lambda world: xr.Dataset(
        {
            "entropy": world.genotypes.entropy(),
            "mean_cvrg": (
                world.communities.data
                * world.metagenotypes.sum("allele").mean("position")
            ).sum("sample").pipe(np.log),
        }
    ),
    row_linkage_func=lambda world: world.genotypes.cosine_linkage(),
    col_linkage_func=lambda world: world.metagenotypes.linkage("sample"),
    # Disabled colour-normalisation alternatives:
    #     norm=mpl.colors.SymLogNorm(linthresh=1e-2),
    #     norm=mpl.colors.PowerNorm(1),
)

In [None]:
# Genotype plot of the fit. Rows are ordered by the genotypes' cosine linkage,
# columns by the metagenotype linkage along "position".
sf.plot.plot_genotype(
    est,
    row_colors_func=lambda world: xr.Dataset(
        {
            "entropy": world.genotypes.entropy(),
            "mean_cvrg": (
                world.communities.data
                * world.metagenotypes.sum("allele").mean("position")
            ).sum("sample").pipe(np.log),
        }
    ),
    row_linkage_func=lambda world: world.genotypes.cosine_linkage(),
    col_linkage_func=lambda world: world.metagenotypes.linkage(dim="position"),
)

In [None]:
# `plot_missing` for the fit, using the same row annotations and orderings as
# the genotype plot.
sf.plot.plot_missing(
    est,
    row_colors_func=lambda world: xr.Dataset(
        {
            "entropy": world.genotypes.entropy(),
            "mean_cvrg": (
                world.communities.data
                * world.metagenotypes.sum("allele").mean("position")
            ).sum("sample").pipe(np.log),
        }
    ),
    row_linkage_func=lambda world: world.genotypes.cosine_linkage(),
    col_linkage_func=lambda world: world.metagenotypes.linkage(dim="position"),
)

In [None]:
# Metagenotype plot of the fit, with rows ordered by the metagenotype linkage
# along "position".
sf.plot.plot_metagenotype(
    est,
    row_linkage_func=lambda world: world.metagenotypes.linkage(dim="position"),
    # Disabled alternatives (row annotations / genotype-based row ordering):
    #     row_colors_func=lambda w: xr.Dataset(dict(
    #         entropy=w.genotypes.entropy(),
    #         mean_cvrg=(w.communities.data * w.metagenotypes.sum("allele").mean("position")).sum("sample").pipe(np.log),
    #     )),
    #     row_linkage_func=lambda w: w.genotypes.cosine_linkage(),
)

In [None]:
# Samples to show in the frequency-spectrum plot. Other sets tried previously:
#     ['DS0097_001', 'DS0097_014', 'DS0097_027', 'DS0097_005', 'SS01057']
#     'DS0097_032', 'DS0044_007', 'SS01068', 'SS01147', 'SS01057', 'SS01134', 'SS01163'
sample_list = ["SS01075", "SS01078"]

sf.plot.plot_metagenotype_frequency_spectrum(
    est,
    sample_list=sample_list,
    show_dominant=True,
)

In [None]:
d = est

# Same model structure as the fit, with coordinates copied from the estimated
# world; no hyperparameters are passed here.
_sim_coords = dict(
    sample=d.sample.values,
    position=d.position.values,
    allele=d.allele.values,
    strain=d.strain.values,
)
model_sim = sf.model.ParameterizedModel(
    sf.model_zoo.full_metagenotype_model_structure,
    coords=_sim_coords,
)

# Values to condition on before simulating. epsilon and alpha are replaced by
# constants (1e-5 and 1e5) instead of the estimated values.
_sim_conditions = dict(
    pi=d.communities.values,
    gamma=d.genotypes.discretized().fuzzed().values,
    # Disabled alternative: gamma=d.genotypes.values
    m=d.data["m"].values,
    # Disabled alternatives:
    #     epsilon=d.data['epsilon'].values,
    #     alpha=d.data['alpha'].values,
    epsilon=np.ones_like(d.data["epsilon"].values) * 1e-5,
    alpha=np.ones_like(d.data["alpha"].values) * 1e5,
)

resim = model_sim.condition(**_sim_conditions).simulate_world()

In [None]:
# Compare observed vs. re-simulated frequency spectra, one panel per sample.
# sample_list = ['SS01038', 'SS01054','SS01052', 'SS01063', 'DS0485_002', 'DS0097_001', 'DS0097_027']
sample_list = [
    'SS01075',
    'SS01078',
]

fig, axs = plt.subplots(3, 3, figsize=(11, 9), sharey=True)
_panels = axs.flatten()
# Number of panels that actually receive a sample (zip stops at the shorter input).
_n_drawn = min(len(sample_list), len(_panels))

for sample, ax in zip(sample_list, _panels):
    sf.plot.plot_metagenotype_frequency_spectrum_comparison(dict(obs=obs, resim=resim), sample=sample, ax=ax)
    ax.set_yscale('log')
    ax.set_ylim(1, 1e4)

# Hide grid cells that received no sample so the figure has no empty axes.
for _spare_panel in _panels[_n_drawn:]:
    _spare_panel.set_visible(False)

# `plt.legend()` acts on the *current* axes, which after `plt.subplots` is the
# last panel created rather than a panel that was drawn into. Attach the legend
# explicitly to the last drawn panel instead.
if _n_drawn:
    _panels[_n_drawn - 1].legend()

In [None]:
# Per-sample comparison of observed metagenotype entropy against the entropy
# expected from the fitted `p`.
w = est
_data = xr.Dataset(dict(
        mgen_entropy=w.metagenotypes.entropy(),
        expect_entropy=w.data['p'].pipe(sf.math.binary_entropy).mean("position"),
        mean_cvrg=w.metagenotypes.sum("allele").mean("position"),
        alpha=w.data['alpha'].pipe(np.log),  # stored as log(alpha)
    )).to_dataframe()

# Explicit figure/axes so the plot can carry labels and a colorbar; the string
# arguments are looked up as columns of `_data`.
fig, ax = plt.subplots()
_points = ax.scatter(x='expect_entropy', y='mgen_entropy', data=_data, s='mean_cvrg', c='alpha')
ax.set_xlabel('expect_entropy (mean binary entropy of p over positions)')
ax.set_ylabel('mgen_entropy (metagenotype entropy)')
ax.set_title('marker area: mean_cvrg')
_colorbar = fig.colorbar(_points, ax=ax)
_colorbar.set_label('log(alpha)')

In [None]:
# Samples with alpha below 10 whose mean of `m` over positions exceeds 20.
_alpha_below_10 = est.data.alpha < 10
_mean_m_above_20 = est.data.m.mean("position") > 20
flagged_sample = sf.pandas_util.idxwhere(
    (_alpha_below_10 & _mean_m_above_20).to_series()
)

In [None]:
# Compare observed vs. re-simulated frequency spectra for the flagged samples,
# one panel per sample on a fixed 3x3 grid.
sample_list = flagged_sample

fig, axs = plt.subplots(3, 3, figsize=(11, 9), sharey=True)
_panels = axs.flatten()
# Materialise once so the count is known whatever container `sample_list` is.
_samples = list(sample_list)
_n_drawn = min(len(_samples), len(_panels))

# zip() stops at the shorter input, so extra samples would be dropped without
# notice; say so explicitly.
if len(_samples) > len(_panels):
    warnings.warn(
        f"Only the first {len(_panels)} of {len(_samples)} samples are plotted."
    )

for sample, ax in zip(_samples, _panels):
    sf.plot.plot_metagenotype_frequency_spectrum_comparison(dict(obs=obs, resim=resim), sample=sample, ax=ax)
    ax.set_yscale('log')
    ax.set_ylim(1, 1e4)

# Hide grid cells that received no sample so the figure has no empty axes.
for _spare_panel in _panels[_n_drawn:]:
    _spare_panel.set_visible(False)

# `plt.legend()` acts on the *current* axes, which after `plt.subplots` is the
# last panel created rather than a panel that was drawn into. Attach the legend
# explicitly to the last drawn panel, if there is one.
if _n_drawn:
    _panels[_n_drawn - 1].legend()

In [None]:
# Metagenotype plot of the fit, with columns ordered by the metagenotype
# linkage along "sample".
sf.plot.plot_metagenotype(
    est,
    col_linkage_func=lambda world: world.metagenotypes.linkage("sample"),
)

In [None]:
# Expected-fractions plot of the fit, using the same column ordering
# (metagenotype linkage along "sample").
sf.plot.plot_expected_fractions(
    est,
    col_linkage_func=lambda world: world.metagenotypes.linkage("sample"),
)

In [None]:
# Prediction-error plot of the fit with default annotations.
sf.plot.plot_prediction_error(
    est,
    # Disabled alternative: colour columns by log(alpha).
    #     col_colors_func=lambda w: xr.Dataset(dict(
    #         alpha=w.data.alpha.pipe(np.log),
    #     )),
)
