## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
from sfacts.data import load_input_data, select_informative_positions
import numpy as np
from sfacts.logging_util import info
from sfacts.pandas_util import idxwhere
from sfacts.workflow import fit_to_data
import sfacts as sf
import matplotlib as mpl
import matplotlib.pyplot as plt
import warnings
import torch
import pandas as pd
import scipy as sp
from scipy.spatial.distance import braycurtis, cosine, pdist
from tqdm import tqdm
import seaborn as sns
import pickle
from lib.plot import rotate_xticklabels


def linear_distance(linear_index):
    """Pairwise absolute distance between the entries of a 1-D series.

    Parameters
    ----------
    linear_index : pandas.Series
        One numeric coordinate per label (the notebook passes
        ``contig_position`` indexed by position id).

    Returns
    -------
    pandas.DataFrame
        Square, symmetric frame whose rows and columns are both labelled by
        ``linear_index.index`` and whose entries are ``|x_i - x_j|``.
    """
    # Imported locally: the notebook only imports `squareform` in a much later
    # cell, so calling this helper before that cell ran raised NameError.
    from scipy.spatial.distance import squareform

    # pdist expects a 2-D table, hence the single-column frame.
    coordinates = linear_index.to_frame()
    return pd.DataFrame(
        squareform(pdist(coordinates, metric='cityblock')),
        index=coordinates.index,
        columns=coordinates.index,
    )


# Default resolution (dots per inch) for every matplotlib figure created below.
mpl.rcParams['figure.dpi']= 120

In [None]:
from lib.plot import ordination_plot
from lib.pandas import align_indexes

In [None]:
# Silence the torch.jit TracerWarning about torch.tensor results being
# registered as constants in the trace.  Only warnings of this category whose
# text matches `message` are ignored; all other warnings still surface.
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=torch.jit.TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

In [None]:
# SNP dictionary from the gt-pro database: one row per position id, kept only
# for the two species analysed in this notebook (100022 and 102506).
all_species_position_meta_ = (
    pd.read_table(
        '/pollard/data/gt-pro-db/variants_main.covered.hq.snp_dict.tsv',
        names=['species_id', 'position', 'contig', 'contig_position', 'ref', 'alt'],
    )
    .set_index('position')
    .loc[lambda x: x.species_id.isin([100022, 102506])]
)

In [None]:
# Sanity check: column dtypes, row count and memory use of the SNP metadata.
all_species_position_meta_.info()

In [None]:
# Per-library sample metadata, keyed by NCBI accession number.
sample_meta = pd.read_table('raw/shi2019s13.tsv').set_index('NCBI Accession Number')
# Number of libraries per (study, continent).  `.size()` is the built-in
# equivalent of `.apply(len)` and does not hand each group frame to Python.
sample_meta.groupby(['Study', 'Continent']).size()

In [None]:
# Studies used for the restricted contingency analyses further down; the
# comments on those cells describe them as studies believed not to contain
# multiple metagenomes from the same/related individuals.
select_studies = ['CM_madagascar', 'Bengtsson-PalmeJ_2015', 'FengQ_2015', 'LiJ_2017', 'LomanNJ_2013']

## All MGEN

### 102506 (Escherichia)

In [None]:
# Load the GT-Pro pileup for species 102506; `data` is reused (and later
# overwritten) by every species section of this notebook.
info("Loading input data.")
data = load_input_data(['data/core.sp-102506.gtpro-pileup.nc'])

In [None]:
# Filter the pileup, subsample `npos` positions and fit the model.
# Outputs used by later cells:
#   mrg_ss   - fitted parameters, read below via the keys 'pi', 'gamma',
#              'delta', 'alpha' and 'epsilon'
#   data_fit - the subset of `data` that was fit (provides `library_id`)
#   history  - loss trace, plotted in the next cell
# NOTE(review): device='cuda' presumably requires a CUDA-capable GPU — confirm
# before running on a CPU-only machine.
mrg_ss, data_fit, history = sf.workflow.filter_subsample_and_fit(
    data,
    incid_thresh=0.1,
    cvrg_thresh=0.05,
    npos=1000,
    preclust=False,
#     preclust_kwargs=dict(
#         thresh=0.1,
#         additional_strains_factor=0.,
#         progress=True,
#     ),
    # Keyword arguments forwarded to the fitting step.
    fit_kwargs=dict(
        s=400,
        gamma_hyper=0.01,
        pi_hyper=0.01,
        rho_hyper=0.5,
        mu_hyper_mean=5,
        mu_hyper_scale=5.,
        m_hyper_r=10.,
        delta_hyper_temp=0.1,
        delta_hyper_p=0.9,
        alpha_hyper_hyper_mean=100.,
        alpha_hyper_hyper_scale=10.,
        alpha_hyper_scale=0.5,
#         alpha_hyper_hyper_mean=10000.,
#         alpha_hyper_hyper_scale=0.001,
#         alpha_hyper_scale=0.001,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        device='cuda',
        lag=100,
        lr=1e-1,
        progress=True
    ),
    # Keyword arguments forwarded to the post-fit clustering step.
    postclust_kwargs=dict(
        thresh=0.1,
        progress=True,
    ),
    seed=1,
)

In [None]:
# Loss trace of the fit above, for a visual check of the optimisation.
sf.plot.plot_loss_history(history)

In [None]:
# Distribution of the fitted 'alpha' values on a log10 scale.
plt.hist(np.log10(mrg_ss['alpha']), bins=100)
#plt.yscale('log')
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# Distribution of the fitted 'epsilon' values on a log10 scale.
plt.hist(np.log10(mrg_ss['epsilon']), bins=100)
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# Column positions of the 50 strains whose 'pi' entry exceeds 0.75 in the most
# samples (argsort is ascending, so the tail holds the largest counts).
top_strains = (mrg_ss['pi'] > 0.75).sum(0).argsort()[-50:]
# Row positions of the 100 samples in which the most of those strains have a
# 'pi' entry above 0.25.
top_samples = ((mrg_ss['pi'][:,top_strains] > 0.25).sum(1)).argsort()[-100:]

# Heatmap of that sample x strain sub-matrix.  Rows are coloured by
# log10(alpha) and columns by the mean masked genotype entropy of each strain.
# NOTE(review): the colormap is called on raw values, which matplotlib treats
# as already normalised to [0, 1] — confirm the inputs are in that range.
sf.plot.plot_community(
    mrg_ss['pi'][top_samples][:, top_strains],
    yticklabels=1,
    row_colors=mpl.cm.viridis(np.log10(mrg_ss['alpha'][top_samples])),
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]),
    norm=mpl.colors.PowerNorm(1/3),
)

In [None]:
# Genotype plot for all strains, after masking genotype entries with `delta`
# via mask_missing_genotype.
grid = sf.plot.plot_genotype(
    sf.genotype.mask_missing_genotype(mrg_ss['gamma'], mrg_ss['delta']), scalex=0.06, scaley=0.01, dheight=4, dwidth=0.2, xticklabels=0, tree_kws=dict(lw=1),
#     col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]),
    
)

In [None]:
# Unmasked genotypes ('gamma') of the top strains only, coloured by each
# strain's mean masked genotype entropy.
grid = sf.plot.plot_genotype(
    mrg_ss['gamma'][top_strains],
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]),
)

In [None]:
# 'delta' values of the top strains, using the same column colouring as the
# genotype plot.
grid = sf.plot.plot_missing(
    mrg_ss['delta'][top_strains],
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]),
)

In [None]:
# Allele-frequency estimate computed directly from the counts: the 'alt'
# allele counts against the counts summed over the allele dimension.
# It is indexed by library id in the sample-pair distance cell below.
p_estimate = sf.genotype.counts_to_p_estimate(
    data_fit.sel(allele='alt'), data_fit.sum('allele')
)

#sf.plot.plot_genotype(p_estimate.values[top_samples])

#### Biogeography

In [None]:
# Construct composition matrix for samples with biogeography data.
#   composition    - fitted 'pi', one row per fitted library
#   meta           - sample metadata for those libraries, dropping libraries
#                    that have no 'Sample ID' entry in `sample_meta`
#   composition_bg - `composition` restricted to the rows kept in `meta`

composition = pd.DataFrame(mrg_ss['pi'], index=data_fit.library_id)
meta = sample_meta.reindex(composition.index).dropna(subset=['Sample ID'])
composition_bg = composition.reindex(meta.index)

In [None]:
# Composition of the VatanenT_2016 libraries, restricted to the strains that
# exceed 0.5 in more than one of those libraries.
is_vatanen = meta['Study'].isin(['VatanenT_2016'])
d = composition_bg[is_vatanen]
strains = idxwhere((d > 0.5).sum() > 1)

sf.plot.plot_community(
    d.loc[:, strains],
    yticklabels=1,
    norm=mpl.colors.PowerNorm(1/3),
)

In [None]:
# TODO: This is a giant contingency table,
# and the p-value on a chisq test shows clearly that strains clump
# into countries.

# Rows: country.  Columns: strain.  Entries: number of libraries from that
# country in which the strain has the largest composition value.
contingency = (
    composition_bg
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)

# Null table: same counting after permuting which library label each
# composition row carries.  The shuffle is seeded so that the sanity check
# below gives the same answer on every run instead of failing at random.
null_contingency = (
    composition_bg
    .set_index(composition_bg.sample(frac=1.0, random_state=0).index)
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)
# Sanity check: the permuted table should not look significant.
assert sp.stats.chi2_contingency(null_contingency)[1] > 0.01

print(sp.stats.chi2_contingency(contingency))

In [None]:
# Same analysis, but carefully selecting studies that I don't believe have
# multiple metagenomes from same/related individuals.

# Country x dominant-strain counts, using only libraries from `select_studies`.
contingency2 = (
    composition_bg
    [meta['Study'].isin(select_studies)]
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)

# Null table from permuted library labels; seeded so the sanity check below is
# reproducible rather than failing at random.
null_contingency2 = (
    composition_bg
    [meta['Study'].isin(select_studies)]
    .set_index(
        composition_bg[meta['Study'].isin(select_studies)]
        .sample(frac=1.0, random_state=0)
        .index
    )
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)
# Sanity check: the permuted table should not look significant.
assert sp.stats.chi2_contingency(null_contingency2)[1] > 0.01

print(sp.stats.chi2_contingency(contingency2))

In [None]:
# Same analysis, but carefully selecting studies that I don't believe have
# multiple metagenomes from same/related individuals.
# And clustering by study rather than country.

# Study x dominant-strain counts, using only libraries from `select_studies`.
contingency3 = (
    composition_bg
    [meta['Study'].isin(select_studies)]
    .groupby(meta['Study'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)

# Null table from permuted library labels; seeded so the sanity check below is
# reproducible rather than failing at random.
null_contingency3 = (
    composition_bg
    [meta['Study'].isin(select_studies)]
    .set_index(
        composition_bg[meta['Study'].isin(select_studies)]
        .sample(frac=1.0, random_state=0)
        .index
    )
    .groupby(meta['Study'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)
# Sanity check: the permuted table should not look significant.
assert sp.stats.chi2_contingency(null_contingency3)[1] > 0.01

print(sp.stats.chi2_contingency(contingency3))

In [None]:
# Number of libraries per selected study (`.size()` replaces `.apply(len)`).
meta[meta['Study'].isin(select_studies)].groupby('Study').size()

In [None]:
# Libraries per country among the selected studies (`.size()` replaces
# `.apply(len)`).
count_individuals = meta[meta['Study'].isin(select_studies)].groupby('Country').size()

# Row-normalise once: fraction of each country's libraries in which a strain
# is dominant.  The original recomputed this table for the ranking and again
# for the plot.
dominant_fraction = contingency2.apply(lambda x: x / x.sum(), axis=1)

# The 20 strains with the highest mean fraction across countries.
top_20_strains = dominant_fraction.mean().sort_values(ascending=False).head(20).index

ax = (
    dominant_fraction
    .loc[['CHN', 'MDG', 'AUT', 'DEU', 'SWE'], top_20_strains]
    .plot
    .bar(stacked=True, color=mpl.cm.tab20(np.linspace(0, 1, num=20)))
)
#ax.legend_.set_visible(False)
ax.legend(bbox_to_anchor=(1, 1), title='Top 20 Strains')

ax.set_ylabel('Fraction samples where dominant')

In [None]:
# Country x study table of library counts, restricted to the selected studies
# (`.size()` replaces `.apply(len)`).
meta.groupby(['Study', 'Country']).size().unstack(fill_value=0).loc[select_studies].T

In [None]:
# Libraries per country among the selected studies (`.size()` replaces
# `.apply(len)`).
count_individuals = meta[meta['Study'].isin(select_studies)].groupby('Country').size()

# Row-normalise once: fraction of each study's libraries in which a strain is
# dominant.  The original recomputed this table for the ranking and again for
# the plot.
dominant_fraction = contingency3.apply(lambda x: x / x.sum(), axis=1)

# The 20 strains with the highest mean fraction across studies.
top_20_strains = dominant_fraction.mean().sort_values(ascending=False).head(20).index

ax = (
    dominant_fraction
    .loc[:, top_20_strains]
    .plot
    .bar(stacked=True, color=mpl.cm.tab20(np.linspace(0, 1, num=20)))
)
#ax.legend_.set_visible(False)
ax.legend(bbox_to_anchor=(1, 1), title='Top 20 Strains')

ax.set_ylabel('Fraction samples where dominant')

In [None]:
# Libraries per (continent, country, study) group (`.size()` replaces
# `.apply(len)`).
count_individuals = meta.groupby([meta['Continent'], meta['Country'], meta['Study']]).size()

# Per group: fraction of libraries in which each strain is dominant.
d = (
    composition_bg
    .groupby([meta['Continent'], meta['Country'], meta['Study']])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
    .sort_index()
    .apply(lambda x: x / x.sum(), axis=1)
)
# NOTE: from here on `top_strains` holds strain column labels, no longer the
# integer positions assigned to the same name earlier in this section.
top_strains = d.mean().sort_values(ascending=False).head(15).index

# Keep the 15 top strains, lump the remainder into 'other', and drop groups
# with fewer than 10 libraries.
d = d.loc[:, top_strains].assign(other=1 - d.loc[:, top_strains].sum(axis=1)).drop(idxwhere(count_individuals < 10))

ax = (
    d
    .plot
    .bar(
        stacked=True, color=mpl.cm.tab20(np.linspace(0, 1, num=20)),
        figsize=(10, 5)
    )
)
#ax.legend_.set_visible(False)
ax.legend(bbox_to_anchor=(1, 1), title='Top Strains')

ax.set_ylabel('Fraction samples where dominant')
rotate_xticklabels()

In [None]:
# Table of library pairs, keyed by the two run accessions.
sample_pairs = pd.read_table('raw/shi2019s14.tsv').set_index(['Sample Run 1', 'Sample Run 2'])

bc_dist = {}
cos_dist = {}

# For each pair with both libraries present in `composition_bg`: Bray-Curtis
# dissimilarity of the compositions and cosine distance of the allele-frequency
# estimates.  Pairs with a missing library get no entry (NaN after the assign).
for libraryA, libraryB in tqdm(sample_pairs.index):
    both_present = (libraryA in composition_bg.index) and (libraryB in composition_bg.index)
    if both_present:
        pair = (libraryA, libraryB)
        bc_dist[pair] = braycurtis(composition_bg.loc[libraryA], composition_bg.loc[libraryB])
        cos_dist[pair] = cosine(p_estimate.loc[libraryA], p_estimate.loc[libraryB])

sample_pairs = sample_pairs.assign(bc=pd.Series(bc_dist), cos=pd.Series(cos_dist))

In [None]:
# Bray-Curtis dissimilarity per pair, split by 'Group Type'.  x/y are passed
# by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.stripplot(x='Group Type', y='bc', data=sample_pairs, alpha=0.2)

In [None]:
# Cosine distance per pair, split by 'Group Type'.  x/y are passed by keyword:
# seaborn >= 0.12 no longer accepts them positionally.
sns.stripplot(x='Group Type', y='cos', data=sample_pairs, alpha=0.2)

In [None]:
# Joint distribution of the two pair distances.  x/y are passed by keyword:
# seaborn >= 0.12 no longer accepts them positionally.
sns.jointplot(x='cos', y='bc', data=sample_pairs, kind='hex', norm=mpl.colors.PowerNorm(1/5))

#### Diversity estimation

In [None]:
from collections import defaultdict

# Collector's curve: walk through the libraries in random order and record how
# many distinct dominant strains have been seen so far.
rarefaction = []
strain_counts = defaultdict(int)
# The shuffle is seeded so that the curve is identical on every run.
for strain_id in composition_bg.idxmax(axis=1).sample(frac=1.0, random_state=0).values:
    strain_counts[strain_id] += 1
    rarefaction.append(len(strain_counts))
rarefaction = np.array(rarefaction)

plt.plot(rarefaction)
# y = x reference: every new library contributing a new dominant strain.
plt.plot([0, 400], [0, 400], lw=1, linestyle='--', color='k')

In [None]:
# Number of libraries in which each strain exceeds 0.1.
strain_incidence = (composition_bg > 1e-1).sum()

# Only strains detected in at least one library count as observed.  The
# original used len(strain_incidence), which also counted strains that never
# pass the threshold and so inflated the estimate.
observed_total = (strain_incidence > 0).sum()
observed_singletons = (strain_incidence == 1).sum()
observed_doubletons = (strain_incidence == 2).sum()

if observed_doubletons > 0:
    chao2 = observed_total + ((observed_singletons**2) / (2 * observed_doubletons))
else:
    # No doubletons: the ratio above is undefined, so use the usual
    # zero-doubleton fallback Q1 * (Q1 - 1) / 2 instead of dividing by zero.
    chao2 = observed_total + (observed_singletons * (observed_singletons - 1)) / 2
print(chao2)

In [None]:
from collections import defaultdict

# Collector's curve over all fitted libraries: walk through them in random
# order and record how many distinct dominant strains have been seen so far.
fig = plt.figure(figsize=(5, 5))
rarefaction = []
strain_counts = defaultdict(int)
# The shuffle is seeded so that the curve is identical on every run.
for strain_id in composition.idxmax(axis=1).sample(frac=1.0, random_state=0).values:
    strain_counts[strain_id] += 1
    rarefaction.append(len(strain_counts))
rarefaction = np.array(rarefaction)

plt.plot(rarefaction)
# y = x reference: every new library contributing a new dominant strain.
plt.plot([0, 250], [0, 250], lw=1, linestyle='--', color='k')

plt.ylabel('Observed genotype clusters')
plt.xlabel('Number of samples')
plt.title("Escherichia")

In [None]:
# Number of libraries in which each strain exceeds 0.1.
strain_incidence = (composition > 1e-1).sum()

# Only strains detected in at least one library count as observed.  The
# original used len(strain_incidence), which also counted strains that never
# pass the threshold and so inflated the estimate.
observed_total = (strain_incidence > 0).sum()
observed_singletons = (strain_incidence == 1).sum()
observed_doubletons = (strain_incidence == 2).sum()

if observed_doubletons > 0:
    chao2 = observed_total + ((observed_singletons**2) / (2 * observed_doubletons))
else:
    # No doubletons: the ratio above is undefined, so use the usual
    # zero-doubleton fallback Q1 * (Q1 - 1) / 2 instead of dividing by zero.
    chao2 = observed_total + (observed_singletons * (observed_singletons - 1)) / 2
print(chao2)

#### Full Length

In [None]:
# Filter, subsample and fit with the same settings as the exploratory fit
# earlier in this section, using the workflow that also returns
# `informative_positions` (used to label positions in the LD cells below).
# NOTE(review): per its name this workflow presumably refits genotypes after
# the subsampled fit — confirm against sfacts.workflow.
est, data_filt, informative_positions, position_ss = sf.workflow.filter_subsample_fit_and_refit_genotypes(
    data,
    incid_thresh=0.1,
    cvrg_thresh=0.05,
    npos=1000,
    preclust=False,
#     preclust_kwargs=dict(
#         thresh=0.1,
#         additional_strains_factor=0.,
#         progress=True,
#     ),
    fit_kwargs=dict(
        s=400,
        gamma_hyper=0.01,
        pi_hyper=0.01,
        rho_hyper=0.5,
        mu_hyper_mean=5,
        mu_hyper_scale=5.,
        m_hyper_r=10.,
        delta_hyper_temp=0.1,
        delta_hyper_p=0.9,
        alpha_hyper_hyper_mean=100.,
        alpha_hyper_hyper_scale=10.,
        alpha_hyper_scale=0.5,
#         alpha_hyper_hyper_mean=10000.,
#         alpha_hyper_hyper_scale=0.001,
#         alpha_hyper_scale=0.001,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        device='cuda',
        lag=100,
        lr=1e-1,
        progress=True
    ),
    postclust_kwargs=dict(
        thresh=0.1,
        progress=True,
    ),
    seed=1,
)

# Persist the estimate.  The context manager closes (and therefore flushes)
# the file deterministically; the original left the handle open.
with open('data/core.sp-102506.gtpro-pileup.sf-est.pickle', 'wb') as est_file:
    pickle.dump(est, est_file)

##### LD

In [None]:
# Dimensions of the estimated genotype matrix.
est['gamma'].shape

In [None]:
# Genotypes restricted to the first 100 columns, without clustering the rows.
sf.plot.plot_genotype(est['gamma'][:,:100], row_cluster=False)

In [None]:
# Distribution of the column means of 'gamma' (mean taken over axis 0).
plt.hist(est['gamma'].mean(0), bins=100)
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# Distribution of the column means of the matrix product pi @ gamma.
plt.hist((est['pi'] @ est['gamma']).mean(0), bins=100)
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# SNP metadata rows belonging to species 102506 only.
position_meta = all_species_position_meta_[lambda x: x.species_id == 102506]

In [None]:
# Metadata of the positions returned as `informative_positions` by the refit.
position_meta.loc[informative_positions]

In [None]:
from scipy.spatial.distance import squareform, pdist

def pos_psim(gamma, delta):
    """Squared correlation between every pair of genotype columns.

    `gamma` is first masked with `delta` via
    ``sf.genotype.mask_missing_genotype``.  The masked matrix is transposed so
    that each of its columns becomes one observation for ``pdist``; the
    correlation distance (1 - r) is turned back into r and squared.

    Returns a square array with one row and one column per column of `gamma`.
    """
    masked = sf.genotype.mask_missing_genotype(gamma, delta)
    correlation_distance = pdist(masked.T, metric='correlation')
    correlation = 1 - squareform(correlation_distance)
    return correlation ** 2

# Pairwise squared correlation between positions, labelled on both axes by
# `informative_positions`.
position_sim = pd.DataFrame(pos_psim(est['gamma'], est['delta']), index=informative_positions, columns=informative_positions)

In [None]:
# Mean similarity over all distinct position pairs.  squareform condenses
# 1 - position_sim to its off-diagonal pairs; the outer 1 - ... converts the
# condensed values back to similarities.
(1 - squareform(1 - position_sim)).mean()

In [None]:
# Per contig: how many positions exist in the metadata (`total_count`) and how
# many of them are among the informative positions (`fit_count`, 0 when none).
# `.size()` replaces `.apply(len)`.
snp_info = (
    position_meta
    .groupby('contig')
    .size()
    .to_frame(name='total_count')
    .assign(
        fit_count=position_meta.loc[informative_positions]
        .groupby('contig')
        .size()
    ).fillna(0)
).sort_values('fit_count', ascending=False)

snp_info.head(10)

In [None]:
# Pairwise |contig_position| differences between informative positions, with
# rows and columns sorted by position label.  `axis` is passed by keyword:
# pandas >= 2.0 makes every sort_index argument keyword-only, so the original
# `sort_index(1)` raises TypeError there.
position_ldist_ = linear_distance(
    position_meta.loc[informative_positions]['contig_position']
).sort_index().sort_index(axis=1)

In [None]:
import patsy

# Indicator matrix: 1 where two informative positions lie on the same contig,
# 0 otherwise.  The contig label is one-hot encoded ('contig - 1' drops the
# intercept); the Jaccard distance between two one-hot rows is 0 when they
# encode the same contig and 1 when they differ, so 1 - distance is the
# indicator.
same_contig = pd.DataFrame(
    1 - squareform(
        pdist(
            patsy.dmatrix(
                'contig - 1', data=position_meta.loc[informative_positions]['contig'].to_frame(), return_type='dataframe'
            ),
            'jaccard'),
    ),
    index=informative_positions, columns=informative_positions,

)
#sns.heatmap(same_contig.sort_index().sort_index(1))

In [None]:
# One row per pair of informative positions.  The three condensed vectors only
# line up if their source matrices share the same row/column order.
# `position_ldist_` was sorted by position label, whereas `position_sim` and
# `same_contig` keep the order of `informative_positions`, so the distance
# matrix is re-aligned first.  This is a no-op when the orders already agree.
# NOTE(review): reindex requires unique position labels — confirm.
pair_order = position_sim.index
aligned_ldist = position_ldist_.reindex(index=pair_order, columns=pair_order)

ld_data = pd.DataFrame(dict(
    linear_distance=squareform(aligned_ldist.values),
    same_contig=(squareform(1 - same_contig.values) == 0),
    ld=1 - squareform(1 - position_sim),
))
# Distances between positions on different contigs are meaningless; drop them.
ld_data = ld_data[ld_data.same_contig]
ld_data

In [None]:
# Same-contig pairs separated by more than 150 and less than 2000.
d = ld_data[
        lambda x: x.same_contig & (150 < x.linear_distance) & (x.linear_distance < 2000)
]

# LD against distance, one point per position pair.
plt.scatter(
    x='linear_distance',
    y='ld',
    data=d,
    s=1,
    alpha=0.1,
)


In [None]:
# Hex-binned joint distribution of LD and distance for same-contig pairs at
# distances strictly between 0 and 2000.
sns.jointplot(
    x='linear_distance',
    y='ld',
    data=ld_data[
        lambda x: x.same_contig & (0 < x.linear_distance) & (x.linear_distance < 2000)
    ],
    kind='hex',
    norm=mpl.colors.PowerNorm(1/3)
)

In [None]:
# Mean LD of same-contig pairs at distances strictly between 0 and 100.
ld_data[
        lambda x: x.same_contig & (0 < x.linear_distance) & (x.linear_distance < 100)
    ].ld.mean()

In [None]:
# Mean LD of same-contig pairs at distances strictly between 100 and 200.
ld_data[
        lambda x: x.same_contig & (100 < x.linear_distance) & (x.linear_distance < 200)
    ].ld.mean()

In [None]:
# LD decay: mean LD in consecutive distance bins, drawn over the raw pairs.
stepsize = 25   # bin width (the legend label below hard-codes "25 bp")
right = 5000    # only pairs closer than this are considered

d = ld_data[ld_data.linear_distance < right]

# Mean LD per half-open bin [start, start + stepsize), keyed by bin start.
# Bins without any pair get NaN.
bins = {}
for start in range(0, right, stepsize):
    stop = start + stepsize
    bins[start] = d[(d.linear_distance >= start) & (d.linear_distance < stop)].ld.mean()
    
plt.scatter(
    x='linear_distance',
    y='ld',
    data=d,
    s=1,
    alpha=0.05,
    color='black',
    label='__nolegend__',
)
# Empty scatter drawn only to get a clearly visible legend marker for the
# nearly transparent points above.
plt.scatter([], [], s=10, color='black', label='Locus Pair')
plt.plot(pd.Series(bins), color='red', label='Mean LD (25 bp Bin)')
plt.axhline(0, lw=1, color='red', linestyle='--')
plt.ylabel(r"LD")
plt.xlabel("Distance")
plt.legend(bbox_to_anchor=(0.85, 1.15), ncol=2)

# Rank correlation between distance and LD over the plotted pairs.
print(sp.stats.spearmanr(d['linear_distance'], d['ld']))

### 100022 (F. prausnitzii)

In [None]:
# Load the GT-Pro pileup for species 100022.  This overwrites the `data`
# variable that held the 102506 pileup.
info("Loading input data.")
data = load_input_data(['data/core.sp-100022.gtpro-pileup.nc'])

In [None]:
# Filter the pileup, subsample `npos` positions and fit the model for species
# 100022.  This overwrites `mrg_ss`, `data_fit` and `history` from the 102506
# section.  Compared with that fit: cvrg_thresh 0.5 (was 0.05), npos 500
# (was 1000), s=1000 (was 400) and postclust thresh 0.25 (was 0.1).
# NOTE(review): device='cuda' presumably requires a CUDA-capable GPU — confirm
# before running on a CPU-only machine.
mrg_ss, data_fit, history = sf.workflow.filter_subsample_and_fit(
    data,
    incid_thresh=0.1,
    cvrg_thresh=0.5,
    npos=500,
    preclust=False,
#     preclust_kwargs=dict(
#         thresh=0.1,
#         additional_strains_factor=0.,
#         progress=True,
#     ),
    # Keyword arguments forwarded to the fitting step.
    fit_kwargs=dict(
        s=1000,
        gamma_hyper=0.01,
        pi_hyper=0.01,
        rho_hyper=0.5,
        mu_hyper_mean=5,
        mu_hyper_scale=5.,
        m_hyper_r=10.,
        delta_hyper_temp=0.1,
        delta_hyper_p=0.9,
        alpha_hyper_hyper_mean=100.,
        alpha_hyper_hyper_scale=10.,
        alpha_hyper_scale=0.5,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        device='cuda',
        lag=100,
        lr=1e-1,
        progress=True
    ),
    # Keyword arguments forwarded to the post-fit clustering step.
    postclust_kwargs=dict(
        thresh=0.25,
        progress=True,
    ),
    seed=1,
)

In [None]:
# Loss trace of the 100022 fit, for a visual check of the optimisation.
sf.plot.plot_loss_history(history)

In [None]:
# Distribution of the fitted 'alpha' values on a log10 scale.
plt.hist(np.log10(mrg_ss['alpha']), bins=100)
#plt.yscale('log')
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# Distribution of the fitted 'epsilon' values on a log10 scale.
plt.hist(np.log10(mrg_ss['epsilon']), bins=100)
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# Column positions of the 50 strains whose 'pi' entry exceeds 0.75 in the most
# samples (argsort is ascending, so the tail holds the largest counts).
top_strains = (mrg_ss['pi'] > 0.75).sum(0).argsort()[-50:]
# Row positions of the 100 samples in which the most of those strains have a
# 'pi' entry above 0.25.
top_samples = ((mrg_ss['pi'][:,top_strains] > 0.25).sum(1)).argsort()[-100:]

# Heatmap of that sample x strain sub-matrix.  Rows are coloured by
# log10(alpha) and columns by the mean masked genotype entropy of each strain.
# NOTE(review): the colormap is called on raw values, which matplotlib treats
# as already normalised to [0, 1] — confirm the inputs are in that range.
sf.plot.plot_community(
    mrg_ss['pi'][top_samples][:, top_strains],
    yticklabels=1,
    row_colors=mpl.cm.viridis(np.log10(mrg_ss['alpha'][top_samples])),
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]),
    norm=mpl.colors.PowerNorm(1/3),
)

In [None]:
# Unmasked genotypes ('gamma') of the top strains only, coloured by each
# strain's mean masked genotype entropy.
grid = sf.plot.plot_genotype(
    mrg_ss['gamma'][top_strains],
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]),
)

In [None]:
# 'delta' values of the top strains, using the same column colouring as the
# genotype plot.
grid = sf.plot.plot_missing(
    mrg_ss['delta'][top_strains],
    col_colors=mpl.cm.viridis(sf.evaluation.mean_masked_genotype_entropy(mrg_ss['gamma'], mrg_ss['delta'])[top_strains]),
)

In [None]:
# Allele-frequency estimate computed directly from the counts: the 'alt'
# allele counts against the counts summed over the allele dimension.
p_estimate = sf.genotype.counts_to_p_estimate(
    data_fit.sel(allele='alt'), data_fit.sum('allele')
)

# Raw estimates for the top samples, plotted like a genotype matrix.
sf.plot.plot_genotype(p_estimate.values[top_samples])

In [None]:
# Composition matrix for this species, restricted to libraries with
# biogeography metadata.  `meta` is rebuilt from *this* fit's libraries, in the
# same way as in the 102506 section.  The original reused the `meta` derived
# from the 102506 fit, which inserts all-NaN rows for libraries missing from
# this fit and silently drops libraries that only this fit contains.
composition = pd.DataFrame(mrg_ss['pi'], index=data_fit.library_id)
meta = sample_meta.reindex(composition.index).dropna(subset=['Sample ID'])
composition = composition.reindex(meta.index)

In [None]:
# Composition of the VatanenT_2016 libraries, restricted to the strains that
# exceed 0.5 in more than one of those libraries.
is_vatanen = meta['Study'].isin(['VatanenT_2016'])
d = composition[is_vatanen]
strains = idxwhere((d > 0.5).sum() > 1)

sf.plot.plot_community(
    d.loc[:, strains],
    yticklabels=1,
    norm=mpl.colors.PowerNorm(1/3),
)

In [None]:
# TODO: This is a giant contingency table,
# and the p-value on a chisq test shows clearly that strains clump
# into countries.

# Rows: country.  Columns: strain.  Entries: number of libraries from that
# country in which the strain has the largest composition value.
contingency = (
    composition
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)

# Null table: same counting after permuting which library label each
# composition row carries.  The shuffle is seeded so that the sanity check
# below gives the same answer on every run instead of failing at random.
null_contingency = (
    composition
    .set_index(composition.sample(frac=1.0, random_state=0).index)
    .groupby(meta['Country'])
    .apply(lambda d: d.idxmax(axis=1).value_counts())
    .unstack(fill_value=0)
)
# Sanity check: the permuted table should not look significant.
assert sp.stats.chi2_contingency(null_contingency)[1] > 0.01

print(sp.stats.chi2_contingency(contingency))

In [None]:
# Table of library pairs, keyed by the two run accessions.
sample_pairs = pd.read_table('raw/shi2019s14.tsv').set_index(['Sample Run 1', 'Sample Run 2'])

bc_dist = {}
cos_dist = {}

# For each pair with both libraries present in `composition`: Bray-Curtis
# dissimilarity of the compositions and cosine distance of the allele-frequency
# estimates.  Pairs with a missing library get no entry (NaN after the assign).
for libraryA, libraryB in tqdm(sample_pairs.index):
    both_present = (libraryA in composition.index) and (libraryB in composition.index)
    if both_present:
        pair = (libraryA, libraryB)
        bc_dist[pair] = braycurtis(composition.loc[libraryA], composition.loc[libraryB])
        cos_dist[pair] = cosine(p_estimate.loc[libraryA], p_estimate.loc[libraryB])

sample_pairs = sample_pairs.assign(bc=pd.Series(bc_dist), cos=pd.Series(cos_dist))

In [None]:
# Bray-Curtis dissimilarity per pair, split by 'Group Type'.  x/y are passed
# by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.stripplot(x='Group Type', y='bc', data=sample_pairs, alpha=0.2)

In [None]:
# Cosine distance per pair, split by 'Group Type'.  x/y are passed by keyword:
# seaborn >= 0.12 no longer accepts them positionally.
sns.stripplot(x='Group Type', y='cos', data=sample_pairs, alpha=0.2)

In [None]:
# Joint distribution of the two pair distances.  x/y are passed by keyword:
# seaborn >= 0.12 no longer accepts them positionally.
sns.jointplot(x='cos', y='bc', data=sample_pairs, kind='hex', norm=mpl.colors.PowerNorm(1/3))

In [None]:
from collections import defaultdict

# Collector's curve: walk through the libraries in random order and record how
# many distinct dominant strains have been seen so far.
rarefaction = []
strain_counts = defaultdict(int)
# The shuffle is seeded so that the curve is identical on every run.
for strain_id in composition.idxmax(axis=1).sample(frac=1.0, random_state=0).values:
    strain_counts[strain_id] += 1
    rarefaction.append(len(strain_counts))
rarefaction = np.array(rarefaction)

plt.plot(rarefaction)
# y = x reference: every new library contributing a new dominant strain.
plt.plot([0, 400], [0, 400], lw=1, linestyle='--', color='k')

In [None]:
# Number of libraries in which each strain exceeds 0.1.
strain_incidence = (composition > 1e-1).sum()

# Only strains detected in at least one library count as observed.  The
# original used len(strain_incidence), which also counted strains that never
# pass the threshold and so inflated the estimate.
observed_total = (strain_incidence > 0).sum()
observed_singletons = (strain_incidence == 1).sum()
observed_doubletons = (strain_incidence == 2).sum()

if observed_doubletons > 0:
    chao2 = observed_total + ((observed_singletons**2) / (2 * observed_doubletons))
else:
    # No doubletons: the ratio above is undefined, so use the usual
    # zero-doubleton fallback Q1 * (Q1 - 1) / 2 instead of dividing by zero.
    chao2 = observed_total + (observed_singletons * (observed_singletons - 1)) / 2
print(chao2)

In [None]:
# Filter, subsample and fit species 100022 with the refit workflow.
# NOTE(review): cvrg_thresh is 0.25 here but 0.5 in the exploratory fit of
# this species above — confirm which one is intended.
refit_outputs = sf.workflow.filter_subsample_fit_and_refit_genotypes(
    data,
    incid_thresh=0.1,
    cvrg_thresh=0.25,
    npos=500,
    preclust=False,
#     preclust_kwargs=dict(
#         thresh=0.1,
#         additional_strains_factor=0.,
#         progress=True,
#     ),
    fit_kwargs=dict(
        s=1000,
        gamma_hyper=0.01,
        pi_hyper=0.01,
        rho_hyper=0.5,
        mu_hyper_mean=5,
        mu_hyper_scale=5.,
        m_hyper_r=10.,
        delta_hyper_temp=0.1,
        delta_hyper_p=0.9,
        alpha_hyper_hyper_mean=100.,
        alpha_hyper_hyper_scale=10.,
        alpha_hyper_scale=0.5,
#         alpha_hyper_hyper_mean=10000.,
#         alpha_hyper_hyper_scale=0.001,
#         alpha_hyper_scale=0.001,
        epsilon_hyper_alpha=1.5,
        epsilon_hyper_beta=1.5 / 0.01,
        device='cuda',
        lag=100,
        lr=1e-1,
        progress=True
    ),
    postclust_kwargs=dict(
        thresh=0.25,
        progress=True,
    ),
    seed=1,
)

# The 102506 section unpacks FOUR values from this same workflow (including
# `informative_positions`), while this cell originally unpacked three; one of
# the two would raise ValueError.  Accept either arity.
# NOTE(review): confirm the current return signature and drop the dead branch.
if len(refit_outputs) == 4:
    est, data_filt, informative_positions, position_ss = refit_outputs
else:
    est, data_filt, position_ss = refit_outputs

# Persist the estimate.  The context manager closes (and therefore flushes)
# the file deterministically; the original left the handle open.
with open('data/core.sp-100022.gtpro-pileup.sf-est.pickle', 'wb') as est_file:
    pickle.dump(est, est_file)

In [None]:
# Dimensions of the estimated genotype matrix.
est['gamma'].shape

In [None]:
# Full estimated genotype matrix, with the rows clustered.
sf.plot.plot_genotype(est['gamma'], row_cluster=True)

In [None]:
# Distribution of the column means of 'gamma' (mean taken over axis 0).
plt.hist(est['gamma'].mean(0), bins=100)
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# Distribution of the column means of the matrix product pi @ gamma.
plt.hist((est['pi'] @ est['gamma']).mean(0), bins=100)
# Trailing None keeps the cell from echoing the return value of plt.hist.
None

In [None]:
# SNP metadata rows belonging to species 100022 only (overwrites the 102506
# `position_meta`).
position_meta = all_species_position_meta_[lambda x: x.species_id == 100022]

In [None]:
# Metadata of the positions present in the filtered data set.
position_meta.loc[data_filt.position]