## Imports

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp
import pyro
import pyro.distributions as dist
import torch
from functools import partial
from tqdm import tqdm
import xarray as xr
import warnings
from torch.jit import TracerWarning

In [None]:
# Silence the TracerWarning that torch emits about torch.tensor results being
# registered as constants while tracing.
# FIXME (carried over): the filter could also be narrowed with `module=` /
# `lineno=`, but the right module regex (e.g. for "trace_elbo") is unknown.
warnings.filterwarnings(
    action="ignore",
    message=(
        "torch.tensor results are registered as constants in the trace. "
        "You can safely ignore this warning if you use this function to create "
        "tensors out of constant variables that would be the same every time you "
        "call this function. "
        "In any other case, this might cause the trace to be incorrect."
    ),
    category=TracerWarning,
)

# Device name handed to the fitting workflow below.
device = 'cpu'

## Prototype

### Load Data

In [None]:
# Load the full metagenotype table and filter it down to the positions and
# samples used for fitting. Doubles as a sanity check on sfacts/data.py.
np.random.seed(1)

mgen_all = sf.data.Metagenotypes.load(
    'data/core.sp-100022.gtpro-pileup.nc',
    validate=False,
)

# Two-step filter: variable positions first, then samples with enough coverage.
_mgen_variable = mgen_all.select_variable_positions(thresh=0.02)
mgen_filt = _mgen_variable.select_samples_with_coverage(0.1)
print(mgen_filt.sizes)

In [None]:
# sf.plot.plot_metagenotype(
#     obs
# )

### Fitting

In [None]:
# Value passed as `nstrain` to the fitting workflow in the next cell.
s = 500

In [None]:
# Fit on a random subsample of 500 positions using the
# "collapse strains, then iteratively refit full genotypes" workflow.
# The result `est` is what the plotting cells below consume.
_mgen_subsample = mgen_filt.random_sample(500, 'position')

_fit_hyperparameters = dict(
    gamma_hyper=0.1,
    delta_hyper_r=0.8,
    delta_hyper_temp=0.1,
    rho_hyper=1.0,
    pi_hyper=1.0,
    epsilon_hyper_mode=0.01,
    epsilon_hyper_spread=1.5,
    alpha_hyper_hyper_mean=1000.0,
    alpha_hyper_hyper_scale=0.5,
    alpha_hyper_scale=1.0,
)

_estimation_kwargs = dict(
    lagA=10,
    lagB=100,
    opt=pyro.optim.Adamax({"lr": 1.0}, {"clip_norm": 100}),
    seed=2,
    jit=True,
)

est = sf.workflow.fit_subsampled_metagenotype_collapse_strains_then_iteratively_refit_full_genotypes(
    sf.model_zoo.full_metagenotype_dirichlet_rho_model_structure,
    _mgen_subsample,
    nstrain=s,
    nposition=500,
    thresh=0.01,
    hyperparameters=_fit_hyperparameters,
    stage2_hyperparameters=dict(gamma_hyper=1.0),
    device=device,
    quiet=False,
    estimation_kwargs=_estimation_kwargs,
)

# Loss-history plots are disabled: no history0..history3 objects are bound in
# this notebook.
# sf.plot.plot_loss_history(history0)
# sf.plot.plot_loss_history(history1)
# sf.plot.plot_loss_history(history2)
# sf.plot.plot_loss_history(history3)

### Visualization

In [None]:
def _community_col_colors(w):
    """Build the column annotation tracks passed to plot_community."""
    mean_depth = w.metagenotypes.sum("allele").mean("position")
    return xr.Dataset(dict(
        mgen_entropy=w.metagenotypes.entropy(),
        expect_entropy=w.data['p'].pipe(sf.math.binary_entropy).mean("position"),
        mean_cvrg=mean_depth.pipe(np.sqrt),
        m_hyper_r=w.data['m_hyper_r'],
        alpha=w.data['alpha'].pipe(np.sqrt),
        # Flags entries with alpha below 10 despite mean depth above 20.
        flag=(w.data.alpha < 10) & (mean_depth > 20),
    ))


def _community_row_colors(w):
    """Build the row annotation tracks passed to plot_community."""
    mean_depth = w.metagenotypes.sum("allele").mean("position")
    weighted_depth = (w.communities.data * mean_depth).sum("sample")
    return xr.Dataset(dict(
        entropy=w.genotypes.entropy(),
        mean_cvrg=weighted_depth.pipe(np.sqrt),
    ))


# Community plot of the fit; clustering is disabled on both axes.
sf.plot.plot_community(
    est,
    col_colors_func=_community_col_colors,
    row_colors_func=_community_row_colors,
    scalex=0.1,
    scaley=0.1,
    row_linkage_func=None,
    col_linkage_func=None,
#     col_linkage_func=lambda w: w.genotypes.cosine_linkage(),
#     row_linkage_func=lambda w: sf.data.latent_metagenotypes_linkage(w),
)

In [None]:
# Genotype plot for the first 500 positions of the fit.
# The original cell referenced `est1`, which is never defined in this
# notebook; the fitting cell binds a single result, `est`, so that is used.
# NOTE(review): if an intermediate-stage fit was intended here, it has to be
# exposed by the workflow first.
sf.plot.plot_genotype(
    est,
    row_colors_func=lambda w: xr.Dataset(dict(
        entropy=w.genotypes.entropy(),
        mean_cvrg=(w.communities.data * w.metagenotypes.sum("allele").mean("position")).sum("sample").pipe(np.log),
    )),
    row_linkage_func=None,
    col_linkage_func=None,
#     row_linkage_func=lambda w: w.genotypes.cosine_linkage(),
#     col_linkage_func=lambda w: w.metagenotypes.linkage(dim='position'),
    isel=dict(position=slice(0, 500)),
)

In [None]:
# Metagenotype plot of the fit, with linkage computed on both axes.
# The original cell referenced `est3`, which is never defined in this
# notebook; the fitting cell binds a single result, `est`.
sf.plot.plot_metagenotype(
    est,
#     row_colors_func=lambda w: xr.Dataset(dict(
#         entropy=w.genotypes.entropy(),
#         mean_cvrg=(w.communities.data * w.metagenotypes.sum("allele").mean("position")).sum("sample").pipe(np.log),
#     )),
#     row_linkage_func=lambda w: w.genotypes.cosine_linkage(),
    row_linkage_func=lambda w: w.metagenotypes.linkage(dim='position'),
    col_linkage_func=lambda w: sf.data.latent_metagenotypes_linkage(w),
)

In [None]:
# Histogram of log10 of the per-strain maximum of `communities` over samples.
# Uses `est` (the fit result); `est3` is never defined in this notebook.
plt.hist(est.data.communities.max("sample").pipe(np.log10), bins=50)

In [None]:
# Histogram of log10(alpha) from the fit.
# Uses `est` (the fit result); `est3` is never defined in this notebook.
plt.hist(est.data.alpha.pipe(np.log10), bins=21)

In [None]:
# Histogram of log10(epsilon) from the fit.
# Uses `est` (the fit result); `est3` is never defined in this notebook.
plt.hist(est.data.epsilon.pipe(np.log10), bins=21)

In [None]:
# Histogram of log10(m_hyper_r) from the fit.
# Uses `est` (the fit result); `est3` is never defined in this notebook.
plt.hist(est.data.m_hyper_r.pipe(np.log10), bins=21)

In [None]:
# m_hyper_r against alpha on log-log axes, coloured by mu.
# Uses `est` (the fit result); `est3` is never defined in this notebook.
plt.scatter(est.data.m_hyper_r, est.data.alpha, c=est.data.mu, alpha=0.5, s=5)
plt.yscale('log')
plt.xscale('log')

In [None]:
# Frequency spectra for the first five samples whose name starts with 'DS0097'.
# Uses `est` (the fit result); `est3` is never defined in this notebook.
sample_list = sf.pandas_util.idxwhere(est.sample.str.startswith('DS0097').to_pandas())[:5]

# Alternative hand-picked sample lists kept for reference:
# sample_list = ['DS0097_001', 'DS0097_014', 'DS0097_027', 'DS0097_005', 'SS01057']
# sample_list = [
#     'DS0097_032',
#     'DS0044_002', 'DS0044_005', 'DS0044_006', 'DS0044_007',
#     'DS0044_008', 'DS0044_009', 'DS0044_010',
#     'SS01068', 'SS01147', 'SS01057', 'SS01134', 'SS01163',
#     'SS01171', 'SS01172', 'SS01026', 'SS01022',
# ]

sf.plot.plot_metagenotype_frequency_spectrum(est, sample_list=sample_list, show_dominant=True)

In [None]:
# Re-simulate a world from the fitted parameters so that simulated data can
# be compared with the observations.
# The original cell read from `est2`, which is never defined in this
# notebook; the fitting cell binds a single result, `est`, so that is used.
# NOTE(review): the fit uses full_metagenotype_dirichlet_rho_model_structure
# while the simulation uses full_metagenotype_model_structure — confirm that
# the conditioned variables below match the simulation model.
d = est

model_sim = (
    sf.model.ParameterizedModel(
        sf.model_zoo.full_metagenotype_model_structure,
        coords=dict(
            sample=d.sample.values,
            position=d.position.values,
            allele=d.allele.values,
            strain=d.strain.values,
        ),
    )
)

resim = model_sim.condition(
    pi=d.communities.values,
#     gamma=d.genotypes.discretized().fuzzed().values,
    gamma=d.genotypes.values,
    m=d.data['m'].values,
    epsilon=d.data['epsilon'].values,
    alpha=d.data['alpha'].values,
#     epsilon=np.ones_like(d.data['epsilon'].values) * 1e-5,
#     alpha=np.ones_like(d.data['alpha'].values) * 1e5,
).simulate_world()

In [None]:
# Observed vs. re-simulated frequency spectra, one panel per 'DS0097' sample
# (at most 9 panels; zip() stops when either samples or axes run out).
# Uses `est` (the fit result); `est2` is never defined in this notebook.
# sample_list = ['SS01038', 'SS01054','SS01052', 'SS01063', 'DS0485_002', 'DS0097_001', 'DS0097_027']
sample_list = sf.pandas_util.idxwhere(est.sample.str.startswith('DS0097').to_pandas())

fig, axs = plt.subplots(3, 3, figsize=(11, 9), sharey=True)

# NOTE(review): `obs` is not defined anywhere in the visible notebook; it must
# be bound to the observed world before this cell can run — confirm its source.
last_drawn_ax = None
for sample, ax in zip(sample_list, axs.flatten()):
    sf.plot.plot_metagenotype_frequency_spectrum_comparison(dict(obs=obs, resim=resim), sample=sample, ax=ax)
    ax.set_yscale('log')
    ax.set_ylim(1, 1e4)
    last_drawn_ax = ax

# Attach the legend to the last panel that was actually drawn on. A bare
# plt.legend() targets the last-created subplot, which is empty whenever
# fewer than nine samples are plotted.
if last_drawn_ax is not None:
    last_drawn_ax.legend()

In [None]:
# Expected entropy against metagenotype entropy; marker size is the mean
# depth and colour is log(alpha).
# Uses `est` (the fit result); `est3` is never defined in this notebook.
w = est
_data = xr.Dataset(dict(
        mgen_entropy=w.metagenotypes.entropy(),
        expect_entropy=w.data['p'].pipe(sf.math.binary_entropy).mean("position"),
        mean_cvrg=w.metagenotypes.sum("allele").mean("position"),
        alpha=w.data['alpha'].pipe(np.log),
    )).to_dataframe()

plt.scatter(x='expect_entropy', y='mgen_entropy', data=_data, s='mean_cvrg', c='alpha')

### Big Data Visualization

In [None]:
# Clustered heatmap of the transposed genotype table, colour scale fixed to
# [0, 1] and centred on 0.5.
_genotypes_transposed = est.genotypes.to_pandas().T
sns.clustermap(
    _genotypes_transposed,
    vmin=0,
    vmax=1,
    center=0.5,
    cmap='coolwarm',
)

In [None]:
# Per-sample coverage: allele counts summed, then averaged over positions.
inferred_sample_coverage = est.metagenotypes.sum('allele').mean('position')

_log10_sample_coverage = inferred_sample_coverage.pipe(np.log10)
plt.hist(_log10_sample_coverage, bins=50)
plt.xlabel('log10(mean species/sample coverage)')
plt.ylabel('count');

In [None]:
# Per-strain coverage summed over samples: mean sample depth multiplied by
# the `communities` table, then reduced over the sample dimension.
_sample_depth = est.metagenotypes.sum('allele').mean('position')
_strain_depth_by_sample = _sample_depth * est.communities.data
total_inferred_strain_coverage = _strain_depth_by_sample.sum('sample')

plt.hist(total_inferred_strain_coverage.pipe(np.log10), bins=50);

In [None]:
# Per-strain coverage in its best single sample: same product as the total,
# but reduced with max over the sample dimension.
_sample_depth = est.metagenotypes.sum('allele').mean('position')
_strain_depth_by_sample = _sample_depth * est.communities.data
max_single_sample_inferred_strain_coverage = _strain_depth_by_sample.max('sample')

plt.hist(max_single_sample_inferred_strain_coverage.pipe(np.log10), bins=50);

In [None]:
# Max single-sample coverage against total coverage per strain, log-log axes,
# with a y = x reference line.
_ax = plt.gca()
_ax.scatter(
    max_single_sample_inferred_strain_coverage,
    total_inferred_strain_coverage,
    s=5,
)
_ax.plot([0, 1e3], [0, 1e3])
_ax.set_yscale('log')
_ax.set_xscale('log')

In [None]:
# Clustered heatmap of the transposed missingness table on a fixed [0, 1] scale.
_missingness_transposed = est.missingness.to_pandas().T
sns.clustermap(_missingness_transposed, vmin=0, vmax=1)

In [None]:
# Clustered heatmap of the communities table with a square-root colour scale.
# The [0, 1] limits are set on the PowerNorm itself: passing vmin/vmax next to
# `norm` is either ignored or rejected, depending on the seaborn/matplotlib
# versions in use.
sns.clustermap(
    est.communities.to_pandas(),
    norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=1),
)

In [None]:
# Histogram of the maximum of `communities` over samples, 50 equal bins on [0, 1].
_unit_bin_edges = np.linspace(0, 1, num=51)
plt.hist(est.communities.max('sample'), bins=_unit_bin_edges);

In [None]:
# Histogram of all genotype values, 50 equal bins on [0, 1].
_unit_bin_edges = np.linspace(0, 1, num=51)
_genotype_values = est.genotypes.values.flatten()
plt.hist(_genotype_values, bins=_unit_bin_edges);

In [None]:
# Histogram of all missingness values, 50 equal bins on [0, 1].
_unit_bin_edges = np.linspace(0, 1, num=51)
_missingness_values = est.missingness.values.flatten()
plt.hist(_missingness_values, bins=_unit_bin_edges);

### Biogeography

In [None]:
# Sample metadata table, indexed by NCBI accession number.
sample_meta = pd.read_table('raw/shi2019s13.tsv').set_index('NCBI Accession Number')

# Number of samples per (Study, Continent). GroupBy.size() is the idiomatic
# group count; apply(len) is slower and triggers pandas deprecation warnings.
sample_meta.groupby(['Study', 'Continent']).size()

In [None]:
# Build the composition matrix restricted to samples with biogeography data:
# align the metadata to the composition index, drop rows without a 'Sample ID',
# then align the composition to the surviving metadata rows.
composition = est.communities.to_pandas()

_meta_aligned = sample_meta.reindex(composition.index)
meta = _meta_aligned.dropna(subset=['Sample ID'])

composition_bg = composition.reindex(meta.index)

In [None]:
from sfacts.pandas_util import idxwhere

# Composition rows belonging to the VatanenT_2016 study.
_in_vatanen = meta['Study'].isin(['VatanenT_2016'])
d = composition_bg[_in_vatanen]

# Columns whose value exceeds 0.5 in more than one of those rows.
_n_above_half = (composition_bg[_in_vatanen] > 0.5).sum()
strains = idxwhere(_n_above_half > 1)

# sf.plot.plot_community(
#     d.loc[:, strains],
#     yticklabels=1,
#     norm=mpl.colors.PowerNorm(1/3),
# )

In [None]:
# TODO: This is a giant contingency table,
# and the p-value on a chisq test shows clearly that strains clump
# into countries.

# Country x strain table: for each country, count how often each column is
# the row-wise maximum (idxmax) of the composition matrix.
_by_country = composition_bg.groupby(meta['Country'])
contingency = (
    _by_country
    .apply(lambda d: d.idxmax(1).value_counts())
    .unstack(fill_value=0)
)

# Null table: the same computation after shuffling the row labels, which
# breaks the link between composition rows and metadata.
_shuffled_labels = composition_bg.sample(frac=1.0).index
null_contingency = (
    composition_bg
    .set_index(_shuffled_labels)
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(1).value_counts())
    .unstack(fill_value=0)
)

# Element [1] of the chi2_contingency result is the p-value.
_null_result = sp.stats.chi2_contingency(null_contingency)
assert _null_result[1] > 0.01

print(sp.stats.chi2_contingency(contingency))

In [None]:
# Same analysis, but carefully selecting studies that I don't believe have
# multiple metagenomes from same/related individuals.
# NOTE(review): `select_studies` is not defined in the visible part of this
# notebook — confirm where it is supposed to come from.

_in_selected = meta['Study'].isin(select_studies)
_composition_selected = composition_bg[_in_selected]

contingency2 = (
    _composition_selected
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(1).value_counts())
    .unstack(fill_value=0)
)

# Null table: shuffle the row labels of the selected subset before grouping.
_shuffled_labels = _composition_selected.sample(frac=1.0).index
null_contingency2 = (
    _composition_selected
    .set_index(_shuffled_labels)
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(1).value_counts())
    .unstack(fill_value=0)
)

# Element [1] of the chi2_contingency result is the p-value.
_null_result2 = sp.stats.chi2_contingency(null_contingency2)
assert _null_result2[1] > 0.01

print(sp.stats.chi2_contingency(contingency2))

In [None]:
# Same analysis, but carefully selecting studies that I don't believe have
# multiple metagenomes from same/related individuals.
# And clustering by study rather than country.
# NOTE(review): `select_studies` is not defined in the visible part of this
# notebook — confirm where it is supposed to come from.

_in_selected = meta['Study'].isin(select_studies)
_composition_selected = composition_bg[_in_selected]

contingency3 = (
    _composition_selected
    .groupby(meta['Study'])
    .apply(lambda d: d.idxmax(1).value_counts())
    .unstack(fill_value=0)
)

# Null table: shuffle the row labels of the selected subset before grouping.
_shuffled_labels = _composition_selected.sample(frac=1.0).index
null_contingency3 = (
    _composition_selected
    .set_index(_shuffled_labels)
    .groupby(meta['Study'])
    .apply(lambda d: d.idxmax(1).value_counts())
    .unstack(fill_value=0)
)

# Element [1] of the chi2_contingency result is the p-value.
_null_result3 = sp.stats.chi2_contingency(null_contingency3)
assert _null_result3[1] > 0.01

print(sp.stats.chi2_contingency(contingency3))

In [None]:
# Number of samples per study among the selected studies.
# GroupBy.size() replaces apply(len): same counts, idiomatic, and no
# deprecation warning about applying over the grouping column.
# NOTE(review): `select_studies` is not defined in the visible part of this
# notebook — confirm where it is supposed to come from.
meta[meta['Study'].isin(select_studies)].groupby('Study').size()

In [None]:
# Stacked bars: per country, the fraction of samples in which each of the
# top 20 strains is dominant (selected studies only).
# NOTE(review): `select_studies` is not defined in the visible part of this
# notebook — confirm where it is supposed to come from.
count_individuals = meta[meta['Study'].isin(select_studies)].groupby('Country').size()

# Row-normalise once and reuse the result for both the ranking and the plot.
_fraction_by_country = contingency2.div(contingency2.sum(axis=1), axis=0)

top_20_strains = _fraction_by_country.mean().sort_values(ascending=False).head(20).index

ax = (
    _fraction_by_country
    .loc[['CHN', 'MDG', 'AUT', 'DEU', 'SWE'], top_20_strains]
    .plot
    .bar(stacked=True, color=mpl.cm.tab20(np.linspace(0, 1, num=20)))
)
#ax.legend_.set_visible(False)
ax.legend(bbox_to_anchor=(1, 1), title='Top 20 Strains')

ax.set_ylabel('Fraction samples where dominant')

In [None]:
# Country x study table of sample counts for the selected studies.
# GroupBy.size() replaces apply(len): same counts, idiomatic, and no
# deprecation warning about applying over the grouping columns.
# NOTE(review): `select_studies` is not defined in the visible part of this
# notebook — confirm where it is supposed to come from.
meta.groupby(['Study', 'Country']).size().unstack(fill_value=0).loc[select_studies].T

In [None]:
# Stacked bars: per study, the fraction of samples in which each of the
# top 20 strains is dominant (selected studies only).
# NOTE(review): `select_studies` is not defined in the visible part of this
# notebook — confirm where it is supposed to come from.
count_individuals = meta[meta['Study'].isin(select_studies)].groupby('Country').size()

# Row-normalise once and reuse the result for both the ranking and the plot.
_fraction_by_study = contingency3.div(contingency3.sum(axis=1), axis=0)

top_20_strains = _fraction_by_study.mean().sort_values(ascending=False).head(20).index

ax = (
    _fraction_by_study
    .loc[:, top_20_strains]
    .plot
    .bar(stacked=True, color=mpl.cm.tab20(np.linspace(0, 1, num=20)))
)
#ax.legend_.set_visible(False)
ax.legend(bbox_to_anchor=(1, 1), title='Top 20 Strains')

ax.set_ylabel('Fraction samples where dominant')

In [None]:
# Stacked bars of dominant-strain fractions for every
# (Continent, Country, Study) group with at least 10 samples.
_cohort_keys = [meta['Continent'], meta['Country'], meta['Study']]

# Samples per group; GroupBy.size() is the idiomatic replacement for apply(len).
count_individuals = meta.groupby(_cohort_keys).size()

# Per group: the fraction of samples in which each column is the row-wise maximum.
d = (
    composition_bg
    .groupby(_cohort_keys)
    .apply(lambda d: d.idxmax(1).value_counts())
    .unstack(fill_value=0)
    .sort_index()
    .apply(lambda x: x / x.sum(), axis=1)
)
top_strains = d.mean().sort_values(ascending=False).head(15).index

# Keep the 15 top strains, lump the remainder into `other`, and drop groups
# with fewer than 10 samples.
_top_fraction = d.loc[:, top_strains]
d = (
    _top_fraction
    .assign(other=1 - _top_fraction.sum(1))
    .drop(idxwhere(count_individuals < 10))
)

ax = (
    d
    .plot
    .bar(
        stacked=True, color=mpl.cm.tab20(np.linspace(0, 1, num=20)),
        figsize=(10, 5)
    )
)
#ax.legend_.set_visible(False)
ax.legend(bbox_to_anchor=(1, 1), title='Top Strains')

ax.set_ylabel('Fraction samples where dominant')
# rotate_xticklabels()