# Preamble

In [None]:
# Load IPython's autoreload extension. Mode 0 disables automatic reloading;
# the bare `%autoreload` then reloads all (non-excluded) modules once, right now,
# so edits to project modules imported below are picked up when this cell re-runs.
%load_ext autoreload
%autoreload 0
%autoreload

## Imports

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from lib.pandas_util import idxwhere
import warnings
import torch
import pyro
import scipy as sp


import sfacts as sf
from sfacts.model_zoo import (
    full_metagenotype,
    full_metagenotype_no_missing,
    full_metagenotype_pp_pi,
    full_metagenotype_dirichlet_rho,
    full_metagenotype_dirichlet_rho_no_missing,
    full_metagenotype_pp_pi_no_missing,
#     full_metagenotype_special_meta,
#     full_metagenotype_no_missing_special_meta,
)

from sklearn.cluster import AgglomerativeClustering


from lib.project_style import color_palette, major_allele_frequency_bins
from lib.project_data import metagenotype_db_to_xarray
from lib.plot import ordination_plot, mds_ordination, nmds_ordination
import lib.plot
from lib.plot import construct_ordered_pallete

In [None]:
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=torch.jit.TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

In [None]:
# Render all figures at 200 dpi.
mpl.rcParams.update({'figure.dpi': 200})

## Style

## Load Data

### Set Focal Species

In [None]:
# Focal species; matched against `species_id` in the queries below.
species_of_interest = "104345"

### Load from DB

In [None]:
# Open the project database read-only. A plain sqlite3.connect() on a missing
# path silently creates an empty database file (and the first query then fails
# with a confusing "no such table"); mode=ro fails fast instead and guarantees
# this notebook never writes to the shared data file.
con = sqlite3.connect('file:data/core.a.proc.gtpro.2.denorm.db?mode=ro', uri=True)

In [None]:
# Metagenome libraries of the focal species: reference/alternative allele tallies
# per (library, SNP position). species_of_interest is bound as a query parameter
# instead of being formatted into the SQL text.
_mgen_all = pd.read_sql(
    """
SELECT
  lib_id
, species_position
, reference_tally
, alternative_tally
FROM snp_x_lib
JOIN lib USING (lib_id)
WHERE species_id = ?
  AND lib_type = 'metagenome'
""",
    con=con,
    params=(species_of_interest,),
    index_col=['lib_id', 'species_position'],
)

mgen_all = sf.data.Metagenotypes(metagenotype_db_to_xarray(_mgen_all))

In [None]:
# Keep variable positions (thresh=0.05) and samples passing the 0.05 coverage
# threshold; both thresholds are interpreted by sfacts.
mgen_filt = (
    mgen_all
    .select_variable_positions(thresh=0.05)
    .select_samples_with_coverage(0.05)
)
# Position subset for readable heatmaps, capped at the number of available
# positions (as the depth plots below do) so species with few SNPs still work.
# NOTE(review): no seed is passed, so the subset may differ between runs.
mgen_ss = mgen_filt.random_sample(position=min(mgen_filt.sizes['position'], 1000))

In [None]:
# Dimension sizes of the filtered metagenome metagenotypes.
mgen_filt.sizes

In [None]:
# Droplet libraries of the focal species, same layout as the metagenome query.
# species_of_interest is bound as a query parameter, not formatted into the SQL.
_drplt_all = pd.read_sql(
    """
SELECT
  lib_id
, species_position
, reference_tally
, alternative_tally
FROM snp_x_lib
JOIN lib USING (lib_id)
WHERE species_id = ?
  AND lib_type = 'droplet'
""",
    con=con,
    params=(species_of_interest,),
    index_col=['lib_id', 'species_position'],
)

drplt_all = sf.data.Metagenotypes(metagenotype_db_to_xarray(_drplt_all))

In [None]:
# Map each droplet library (lib_id) to the sample it came from. Selecting the
# column explicitly, instead of .squeeze(), always yields a Series named
# 'sample_id' (a single-row result would otherwise collapse to a scalar); the
# name is required by the `.join(drplt_to_sample, on='sample')` used later.
drplt_to_sample = pd.read_sql(
    "SELECT lib_id, sample_id FROM lib WHERE lib_type = 'droplet'",
    con=con,
    index_col='lib_id',
)['sample_id']
drplt_to_sample.value_counts()

In [None]:
# Align droplets to the positions retained for metagenomes (presumably missing
# positions become zero counts via fill_value=0), then apply a 0.01 coverage filter.
drplt_filt = (
    drplt_all
    .mlift('reindex', position=mgen_filt.position, fill_value=0)
    .select_samples_with_coverage(0.01)
)

In [None]:
# Dimension sizes of the filtered droplet metagenotypes.
drplt_filt.sizes

# Analysis

## Depth duplication across droplets

In [None]:
# Metagenomes and (not yet deduplicated) droplets together, on a shared random
# subset of at most 1000 positions.
nposition = min(mgen_filt.sizes['position'], 1000)

d = (
    sf.data.Metagenotypes.concat(
        dict(mgen=mgen_filt, drplt=drplt_filt),
        dim='sample',
    )
    .select_samples_with_coverage(0.01)
    .random_sample(position=nposition)
)


def _depth_col_annotations(w):
    """Column annotations: focal metagenomes, droplet flag, log mean depth, entropy."""
    return xr.Dataset(dict(
        focal=w.sample.isin(['mgen_SS01009.m', 'mgen_SS01057.m']),
        drplt=w.sample.str.startswith('drplt'),
        depth=w.metagenotypes.mean_depth().pipe(np.log),
        mgen_entropy=w.metagenotypes.entropy('sample'),
    ))


sf.plot.plot_depth(d.to_world(), col_colors_func=_depth_col_annotations)
sf.plot.plot_metagenotype(d.to_world(), col_colors_func=_depth_col_annotations)

In [None]:
distance_threshold = 0.5

# scikit-learn 1.2 renamed AgglomerativeClustering's `affinity` argument to
# `metric`, and 1.4 removed `affinity`; use whichever name this install accepts.
_agglom_metric_kw = (
    'metric' if 'metric' in AgglomerativeClustering().get_params() else 'affinity'
)

# Droplet metagenotypes with columns colored by complete-linkage clusters of their
# total-count (depth) profiles under cosine distance, cut at distance_threshold.
sf.plot.plot_metagenotype(
    drplt_filt.to_world(),
    col_colors_func=lambda d: pd.Series(
        (
            AgglomerativeClustering(
                **{_agglom_metric_kw: 'cosine'},
                distance_threshold=distance_threshold,
                n_clusters=None,
                linkage='complete',
            )
            .fit(d.metagenotypes.total_counts())
            .labels_
        ),
        d.sample,
        name='clust',
    ).to_xarray(),
    row_col_annotation_cmap=mpl.cm.hsv,
    isel=dict(position=slice(0, 1000)),
)

In [None]:
d = drplt_filt

# Cluster droplets (complete linkage, cosine distance between total-count
# profiles, cut at distance_threshold). scikit-learn>=1.2 names the distance
# keyword `metric`; `affinity` was removed in 1.4.
_agglom_metric_kw = (
    'metric' if 'metric' in AgglomerativeClustering().get_params() else 'affinity'
)
drplt_agg = pd.Series(
    (
        AgglomerativeClustering(
            **{_agglom_metric_kw: 'cosine'},
            distance_threshold=distance_threshold,
            n_clusters=None,
            linkage='complete',
        )
        .fit(d.total_counts())
        .labels_
    ),
    d.sample,
    name='clust',
)

drplt_agg_pal = construct_ordered_pallete(drplt_agg, cm='hsv')

# NMDS of droplet distances, colored by cluster.
ordination_plot(
    d.pdist(),
    ordin=lib.plot.nmds_ordination,
    meta=drplt_agg.to_frame(),
    colorby='clust',
    color_palette=drplt_agg_pal,
    ordin_kws={'is_dmat': True},
    fill_legend=False,
    scatter_kws=dict(lw=0),
)
None

In [None]:
d = drplt_filt

# Same droplet clusters, embedded with t-SNE on the precomputed distance matrix.
ordination_plot(
    d.pdist(),
    ordin=lib.plot.tsne_ordination,
    meta=drplt_agg.to_frame(),
    colorby='clust',
    color_palette=drplt_agg_pal,
    ordin_kws=dict(is_dmat=True),
    fill_legend=False,
    scatter_kws={'lw': 0},
);

In [None]:
# Collapse droplets into one pseudo-sample per (source sample, depth cluster) by
# summing their allele tallies; new sample labels are '<sample_id>_<clust>'.
drplt_filt_derep = sf.data.Metagenotypes(
    drplt_filt
    .to_series()
    .reset_index()
    .rename(columns={0: 'tally'})
    .join(drplt_agg, on='sample')
    .join(drplt_to_sample, on='sample')
    .groupby(['sample_id', 'clust', 'position', 'allele'])
    # Aggregate only the tally: the string 'sample' column must not be summed
    # (pandas>=2 concatenates strings in .sum(), which then breaks astype(int)).
    ['tally']
    .sum()
    .astype(int)
    .reset_index()
    .assign(label=lambda x: x.sample_id + '_' + x.clust.astype(str))
    .set_index(['label', 'position', 'allele'])
    .rename_axis(index={'label': 'sample'})
    .tally
    .to_xarray()
)
print(drplt_filt.sizes, drplt_filt_derep.sizes)

In [None]:
d = drplt_filt_derep

# NMDS of the deduplicated droplet pseudo-samples.
ax, *_ = ordination_plot(
    d.pdist(),
    ordin=lib.plot.nmds_ordination,
    meta=d.sample.to_series().to_frame(),
    ordin_kws={'is_dmat': True},
    fill_legend=False,
    scatter_kws=dict(lw=0),
)

# The embedding comes from nmds_ordination, so these are NMDS axes, not
# principal coordinates.
ax.set_xlabel('NMDS1')
ax.set_ylabel('NMDS2')

In [None]:
# Metagenomes together with the deduplicated droplet pseudo-samples, on a random
# subset of at most 1000 positions (capped like the earlier joint depth plot, so
# a species with fewer variable positions does not fail).
_nplot_position = min(mgen_filt.sizes['position'], 1000)
d = (
    sf.data.Metagenotypes.concat(
        dict(mgen=mgen_filt, drplt=drplt_filt_derep),
        dim='sample',
    )
    .select_samples_with_coverage(0.001)
    .random_sample(position=_nplot_position)
)


def _depth_col_annotations(w):
    """Column annotations: focal metagenomes, droplet flag, log mean depth, entropy."""
    return xr.Dataset(dict(
        focal=w.sample.isin(['mgen_SS01009.m', 'mgen_SS01057.m']),
        drplt=w.sample.str.startswith('drplt'),
        depth=w.metagenotypes.mean_depth().pipe(np.log),
        mgen_entropy=w.metagenotypes.entropy('sample'),
    ))


sf.plot.plot_depth(d.to_world(), col_colors_func=_depth_col_annotations)
sf.plot.plot_metagenotype(d.to_world(), col_colors_func=_depth_col_annotations)

In [None]:
# Alternative-allele fraction (no pseudocount) of metagenomes on the shared
# position subset; cells left undrawn (presumably NaN at zero depth) show the
# dark-grey axes background.
_mgen_ss_world = (
    mgen_filt
    .select_samples_with_coverage(0.05)
    .mlift('sel', position=mgen_ss.position)
    .to_world()
)
g = sf.plot.plot_metagenotype(
    world=_mgen_ss_world,
    matrix_func=lambda wld: wld.metagenotypes.alt_allele_fraction(pseudo=0.0).T,
    col_colors_func=None,
)

g.ax_heatmap.set_facecolor('darkgrey')

In [None]:
# Same metagenomes and positions, using plot_metagenotype's default matrix.
_mgen_ss_world = (
    mgen_filt
    .select_samples_with_coverage(0.05)
    .mlift('sel', position=mgen_ss.position)
    .to_world()
)
g = sf.plot.plot_metagenotype(world=_mgen_ss_world, col_colors_func=None)

g.ax_heatmap.set_facecolor('darkgrey')

In [None]:
fig, ax = plt.subplots()

# Allele-frequency spectrum of SS01009.m: 50 equal bins over [0.5, 1.0]
# (presumably major-allele frequency — confirm in sfacts).
sf.plot.plot_metagenotype_frequency_spectrum(
    mgen_filt.to_world(),
    sample_list=['SS01009.m'],
    axs=ax,
    bins=np.linspace(0.5, 1.0, 51),
)
ax.set_ylim(0, 400)

In [None]:
# Mean depth of the three metagenomes whose allele-frequency spectra are plotted
# in the neighboring cells.
mgen_filt.mean_depth().sel(sample=['SS01009.m', 'SS01060.m', 'SS01033.m'])

In [None]:
fig, ax = plt.subplots()

# Allele-frequency spectrum of SS01060.m (same binning as for SS01009.m).
sf.plot.plot_metagenotype_frequency_spectrum(
    mgen_filt.to_world(),
    sample_list=['SS01060.m'],
    axs=ax,
    bins=np.linspace(0.5, 1.0, 51),
)
ax.set_ylim(0, 500)

In [None]:
fig, ax = plt.subplots()

# Allele-frequency spectrum of SS01033.m (same binning as for SS01009.m).
sf.plot.plot_metagenotype_frequency_spectrum(
    mgen_filt.to_world(),
    sample_list=['SS01033.m'],
    axs=ax,
    bins=np.linspace(0.5, 1.0, 51),
)
ax.set_ylim(0, 500)

## Haplotyping

### Fitting

In [None]:
nstrain = 50
nposition = min(int(3e4), mgen_filt.sizes['position'])

# Use the GPU when present; otherwise fall back to CPU instead of failing with a
# CUDA initialization error.
_fit_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fit the strain model to the metagenomes alone (per the workflow's name: fit a
# position subsample, collapse similar strains, then refit full genotypes).
est1 = (
    sf.workflow.fit_subsampled_metagenotype_collapse_strains_then_iteratively_refit_full_genotypes_no_missing(
        full_metagenotype_no_missing.model_structure,
        mgen_filt.random_sample(position=nposition),
        nstrain=nstrain,
        hyperparameters=dict(
            gamma_hyper=0.1,
            rho_hyper=1.0,
            pi_hyper=1.0,
            epsilon_hyper_mode=0.01,
            epsilon_hyper_spread=1.5,
            alpha_hyper_hyper_mean=1000.0,
            alpha_hyper_hyper_scale=0.5,
            alpha_hyper_scale=2.0,
        ),
        nposition=nposition,
        diss_thresh=0.05,
        frac_thresh=1e-3,
        stage2_hyperparameters=dict(gamma_hyper=1.0),
        device=_fit_device,
        dtype=torch.float64,
        quiet=False,
        estimation_kwargs=dict(
            lagA=10,
            lagB=100,
            opt=pyro.optim.Adamax({"lr": 1.0}, {"clip_norm": 100}),
            seed=2,
            jit=True,
        ),
    )
)

In [None]:
e = est1

# Strain (rows) x sample (columns) community composition, both axes ordered by
# hierarchical clustering; PowerNorm(1/2) is a square-root color scale.
g = sf.plot.plot_community(
    e,
    row_linkage_func=lambda wld: wld.genotypes.linkage(dim="strain"),
    col_linkage_func=lambda wld: wld.communities.linkage(dim="sample"),
    row_colors_func=None,
    col_colors_func=None,
    norm=mpl.colors.PowerNorm(1/2),
    pad_height=2.,
)

In [None]:
e = est1

# Fitted genotypes (transposed), positions and strains ordered by clustering.
sf.plot.plot_genotype(
    e,
    row_linkage_func=lambda wld: wld.genotypes.linkage(dim="strain"),
    col_linkage_func=lambda wld: wld.genotypes.linkage(dim="position"),
    row_colors_func=None,
    transpose=True,
    scalex=1e-3,
)

In [None]:
e = est1

# Input metagenotypes: positions ordered by the fitted genotypes' clustering and
# samples by community similarity.
g = sf.plot.plot_metagenotype(
    e,
    col_linkage_func=lambda wld: wld.communities.linkage(dim="sample"),
    row_linkage_func=lambda wld: wld.genotypes.linkage(dim="position"),
    col_colors_func=None,
    scaley=1e-3,
)

In [None]:
e = est1

# Same ordering as above, showing the alternative-allele fraction (no pseudocount).
g = sf.plot.plot_metagenotype(
    e,
    matrix_func=lambda wld: wld.metagenotypes.alt_allele_fraction(pseudo=0.0).T,
    col_linkage_func=lambda wld: wld.communities.linkage(dim="sample"),
    row_linkage_func=lambda wld: wld.genotypes.linkage(dim="position"),
    col_colors_func=None,
    scaley=1e-3,
)

g.ax_heatmap.set_facecolor('darkgrey')

In [None]:
# Group deduplicated droplet pseudo-samples (coverage threshold 0.05) into strain
# types by clustering their estimated genotypes (pseudo=1.0) at thresh=0.1.
_drplt_derep_gt_world = (
    drplt_filt_derep
    .select_samples_with_coverage(0.05)
    .to_estimated_genotypes(pseudo=1.0)
    .to_world()
)
drplt_strain_type = sf.estimation.strain_cluster(_drplt_derep_gt_world, thresh=0.1)

drplt_strain_type_palette = lib.plot.construct_ordered_pallete(
    sorted(drplt_strain_type.unique()),
    cm='Spectral',
)
drplt_strain_type.value_counts()

In [None]:
w = sf.data.World(xr.Dataset({
    'metagenotypes': drplt_filt_derep.select_samples_with_coverage(0.05).data,
    'genotypes': est1.genotypes.data,
}))

# Droplet alt-allele fraction: rows ordered by the fitted genotypes, columns by
# metagenotype similarity and colored by droplet strain type.
g = sf.plot.plot_metagenotype(
    world=w.sel(position=mgen_ss.position),
    matrix_func=lambda wld: wld.metagenotypes.alt_allele_fraction(pseudo=0.0).T,
    row_linkage_func=lambda wld: wld.genotypes.linkage('position'),
    col_linkage_func=lambda wld: wld.metagenotypes.linkage('sample'),
    col_colors_func=None,
    col_colors=drplt_strain_type.map(drplt_strain_type_palette),
    scalex=0.2,
    tree_kws=dict(lw=2.),
    dwidth=1e-5,
    dheight=2.5,
)

g.ax_heatmap.set_facecolor('darkgrey')

In [None]:
w = sf.data.World(xr.Dataset({
    'metagenotypes': drplt_filt_derep.select_samples_with_coverage(0.05).data,
    'genotypes': est1.genotypes.data,
}))

# Same layout as the alt-fraction heatmap, using the default matrix.
g = sf.plot.plot_metagenotype(
    world=w.sel(position=mgen_ss.position),
    row_linkage_func=lambda wld: wld.genotypes.linkage('position'),
    col_linkage_func=lambda wld: wld.metagenotypes.linkage('sample'),
    col_colors_func=None,
    col_colors=drplt_strain_type.map(drplt_strain_type_palette),
    scalex=0.2,
    tree_kws=dict(lw=2.),
    dwidth=1e-5,
    dheight=2.5,
)

In [None]:
w = sf.data.World(xr.Dataset({
    'metagenotypes': drplt_filt_derep.select_samples_with_coverage(0.01).data,
    'genotypes': est1.genotypes.data,
}))

# Looser coverage threshold (0.01), first 2000 positions, alt-allele fraction.
g = sf.plot.plot_metagenotype(
    world=w.isel(position=slice(0, 2000)),
    matrix_func=lambda wld: wld.metagenotypes.alt_allele_fraction(pseudo=0.0).T,
    row_linkage_func=lambda wld: wld.genotypes.linkage('position'),
    col_linkage_func=lambda wld: wld.metagenotypes.linkage('sample'),
    col_colors_func=None,
    scalex=0.2,
    tree_kws=dict(lw=2.),
    dwidth=1e-5,
    dheight=2.5,
)

g.ax_heatmap.set_facecolor('darkgrey')

In [None]:
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform

# Raw (not deduplicated) droplets passing the 0.05 coverage filter, paired with
# the metagenome-only fitted genotypes.
w = sf.data.World(
    xr.Dataset({
        'metagenotypes': drplt_filt.select_samples_with_coverage(0.05).data,
        'genotypes': est1.genotypes.data,
    })
)

def depth_linkage(w, method="complete", optimal_ordering=False, **kwargs):
    """Hierarchical linkage of samples by cosine distance between depth profiles.

    Parameters
    ----------
    w : sfacts World
        Only ``w.metagenotypes.total_counts()`` is used; it is unstacked so that
        positions become columns and each row is one sample's total counts.
    method, optimal_ordering, **kwargs
        Passed through to ``scipy.cluster.hierarchy.linkage``.

    Returns
    -------
    numpy.ndarray
        SciPy linkage matrix over the rows (samples) of the unstacked table.
    """
    depth = w.metagenotypes.total_counts().to_series().unstack('position')
    cdmat = pdist(depth, 'cosine')
    # A sample with zero counts at every plotted position has an undefined (NaN)
    # cosine distance, which `linkage` rejects. Treat it as maximally distant:
    # the cosine distance between non-negative count vectors lies in [0, 1].
    cdmat = np.nan_to_num(cdmat, nan=1.0)
    return linkage(cdmat, method=method, optimal_ordering=optimal_ordering, **kwargs)
    
    

# Raw droplets with columns ordered by depth-profile similarity (depth_linkage)
# and colored by the droplet depth clusters.
g = sf.plot.plot_metagenotype(
    world=w.sel(position=mgen_ss.position),
    row_linkage_func=lambda wld: wld.genotypes.linkage('position'),
    col_linkage_func=depth_linkage,
    col_colors_func=None,
    col_colors=drplt_agg.map(drplt_agg_pal),
    scalex=0.2,
    tree_kws=dict(lw=2.),
    dwidth=1e-5,
    dheight=2.5,
)

g.ax_heatmap.set_facecolor('darkgrey')

In [None]:
w = sf.data.World(xr.Dataset(dict(
    metagenotypes=drplt_filt_derep.select_samples_with_coverage(0.05).data,
    genotypes=est1.genotypes.data
)))

# NMDS of droplet pseudo-samples, colored by strain type.
ax, *_ = ordination_plot(
    w.metagenotypes.pdist(),
    ordin=lib.plot.nmds_ordination,
    meta=drplt_strain_type.to_frame(name='clust'),
    colorby='clust',
    color_palette=drplt_strain_type_palette,
    markersize_palette={'__none__': 40},
    ordin_kws={'is_dmat': True},
    fill_legend=False,
    scatter_kws=dict(lw=0.5, alpha=0.9),
)

# nmds_ordination produces NMDS axes, not principal coordinates.
ax.set_xlabel('NMDS1')
ax.set_ylabel('NMDS2')

In [None]:
e = est1
sample_list = [smp for smp in e.sample.values if smp in ['SS01009.m']]

# Strains with relative abundance > 0.01 in the focal metagenome.
d = e.communities.sel(sample=sample_list).to_series()
_abundant = d[d > 0.01]
strain_list1 = _abundant.reset_index().strain.unique()
print(strain_list1)
d.sort_values(ascending=False).head(10)

In [None]:
e = est1.sel(position=mgen_ss.position)

if len(strain_list1) > 1:
    # Plot the strains that are abundant in the focal sample (strain_list1),
    # located by label. Hard-coded positional indices would point at other
    # strains, or go out of range, whenever the fit changes.
    _focal_strain_idx = np.flatnonzero(np.isin(e.strain.values, strain_list1)).tolist()
    sf.plot.plot_genotype(
        e,
        col_linkage_func=lambda wld: wld.genotypes.linkage('position'),
        row_linkage_func=None,
        row_colors_func=None,
        transpose=True,
        dheight=1e-5,
        dwidth=1e-5,
        scaley=0.5,
        tree_kws=dict(lw=2.),
        isel=dict(strain=_focal_strain_idx),
    )

In [None]:
w = sf.data.Genotypes.concat(dict(
    drplt=drplt_filt_derep.select_samples_with_coverage(0.05).to_estimated_genotypes(),
    est=est1.genotypes.mlift('sel', strain=strain_list1),
), dim='strain')

# Shared NMDS embedding of droplet genotypes and focal-sample strains. Estimated
# strains are given marker size 0 here, so only droplets (colored by strain type)
# are visible; the next cell repeats the plot with them shown.
ax, *_ = ordination_plot(
    w.pdist(),
    ordin=lib.plot.nmds_ordination,
    meta=pd.DataFrame(dict(
        clust=drplt_strain_type.rename(lambda s: 'drplt_' + s),
        is_est=w.strain.str.startswith('est_').to_series(),
    )).fillna(-1),
    colorby='clust',
    color_palette=drplt_strain_type_palette,
    markerby='is_est',
    marker_palette={True: '>', False: 'o'},
    zorderby='is_est',
    markersizeby='is_est',
    markersize_palette={True: 0, False: 40},
    ordin_kws={'is_dmat': True},
    fill_legend=False,
    scatter_kws=dict(lw=0.5, alpha=0.95),
)

# nmds_ordination produces NMDS axes, not principal coordinates.
ax.set_xlabel('NMDS1')
ax.set_ylabel('NMDS2')

In [None]:
w = sf.data.Genotypes.concat(dict(
    drplt=drplt_filt_derep.select_samples_with_coverage(0.05).to_estimated_genotypes(),
    est=est1.genotypes.mlift('sel', strain=strain_list1),
), dim='strain')

# Shared NMDS embedding of droplet genotypes (circles, colored by strain type)
# and focal-sample strains (triangles, filled with -1 for the color column).
ax, *_ = ordination_plot(
    w.pdist(),
    ordin=lib.plot.nmds_ordination,
    meta=pd.DataFrame(dict(
        clust=drplt_strain_type.rename(lambda s: 'drplt_' + s),
        is_est=w.strain.str.startswith('est_').to_series(),
    )).fillna(-1),
    colorby='clust',
    color_palette=drplt_strain_type_palette,
    markerby='is_est',
    marker_palette={True: '>', False: 'o'},
    zorderby='is_est',
    markersizeby='is_est',
    markersize_palette={True: 60, False: 40},
    ordin_kws={'is_dmat': True},
    fill_legend=False,
    scatter_kws=dict(lw=0.5, alpha=0.95),
)

# nmds_ordination produces NMDS axes, not principal coordinates.
ax.set_xlabel('NMDS1')
ax.set_ylabel('NMDS2')

In [None]:
# Droplet pseudo-sample genotypes plus focal-sample strains in one collection.
# NOTE(review): nothing later in this notebook reads this `w` before it is
# reassigned; this cell may be leftover.
w = sf.data.Genotypes.concat(
    {
        'drplt': drplt_filt_derep.select_samples_with_coverage(0.05).to_estimated_genotypes(),
        'est': est1.genotypes.mlift('sel', strain=strain_list1),
    },
    dim='strain',
)



In [None]:
a = drplt_filt_derep.to_estimated_genotypes()
# Drop strain -1 (NOTE(review): presumably the collapsed "other" strain — confirm).
b = est1.genotypes.drop_sel(strain=[-1])

# Droplet (rows) x fitted strain (columns) genotype distances.
_drplt_strain_dist = pd.DataFrame(
    sf.math.genotype_cdist(a.values, b.values),
    index=a.strain,
    columns=b.strain,
)
# Every summary is computed from the pure distance table. Inside assign(), a
# lambda sees the columns assigned before it, so x.min(1) / x.median(1) there
# would also reduce over the best_strain / summary columns.
drplt_accuracy = _drplt_strain_dist.assign(
    best_strain=_drplt_strain_dist.idxmin(axis=1),
    best_of_sample_strains=_drplt_strain_dist[strain_list1].idxmin(axis=1),
    min_dist=_drplt_strain_dist.min(axis=1),
    min_dist_sample_strains=_drplt_strain_dist[strain_list1].min(axis=1),
    # Fraction of positions with nonzero depth (aligned by position in sample order).
    cvrg=(drplt_filt_derep.total_counts() > 0).mean("position"),
    median_dist=_drplt_strain_dist.median(axis=1),
)
pal = construct_ordered_pallete(drplt_accuracy.best_strain.to_list(), cm='Paired')

In [None]:
# Droplet horizontal coverage vs. distance to the closest focal-sample strain,
# colored by the overall closest strain; strains in strain_list1 get a black outline.
for strain, by_strain in drplt_accuracy.groupby('best_strain'):
    for sample_strain, subset in by_strain.groupby('best_of_sample_strains'):
        plt.scatter(
            'cvrg', 'min_dist_sample_strains', data=subset,
            c=[pal[strain]], label='__none__', alpha=0.7,
            lw=int(strain in strain_list1), edgecolor='k',
        )

# Empty scatters act as legend entries, one per closest strain.
for strain in sorted(drplt_accuracy.best_strain.unique()):
    plt.scatter(
        [], [], c=[pal[strain]], label=strain,
        lw=int(strain in strain_list1), edgecolor='k',
    )

plt.xlabel('Horizontal Coverage')
plt.ylabel('Strain Dissimilarity')
plt.ylim(0.0, 0.25)
plt.legend(ncol=2)

In [None]:
# Droplet horizontal coverage vs. median distance across fitted strains, colored
# by the closest strain; strains in strain_list1 get a black outline. (Splitting
# each group further by best_of_sample_strains drew the same points with the same
# style, so each closest-strain group is plotted once.)
for best_strain, d0 in drplt_accuracy.groupby('best_strain'):
    plt.scatter(
        'cvrg', 'median_dist', data=d0,
        c=[pal[best_strain]], label='__none__', alpha=0.7,
        lw=int(best_strain in strain_list1), edgecolor='k',
    )

# Empty scatters act as legend entries, one per closest strain.
for best_strain in sorted(drplt_accuracy.best_strain.unique()):
    plt.scatter(
        [], [], c=[pal[best_strain]], label=best_strain,
        lw=int(best_strain in strain_list1), edgecolor='k',
    )

plt.xlabel('Horizontal Coverage')
# The y values are median_dist, not the minimum distance plotted in the previous cell.
plt.ylabel('Median Strain Dissimilarity')
# NOTE(review): this ylim matches the min-distance plot; medians may exceed 0.25.
plt.ylim(0.0, 0.25)
plt.legend(ncol=2)

In [None]:
from scipy.spatial.distance import pdist, squareform

# Visualize sf.math.genotype_dissimilarity between single-position "genotypes"
# with values on a 201-point grid over [0, 1].
_grid_values = np.linspace(0, 1, num=201)
x = _grid_values.reshape((-1, 1))
d = pd.DataFrame(
    squareform(pdist(x, sf.math.genotype_dissimilarity)),
    index=_grid_values,
    columns=_grid_values,
)

_tick_every = len(x) // 5
ax = sns.heatmap(d, xticklabels=_tick_every, yticklabels=_tick_every, cmap='gray_r')
ax.invert_yaxis()

### Evaluation

In [None]:
e = est1
sample_list = [smp for smp in e.sample.values if smp in ['SS01057.m', 'SS01009.m']]

# Genotypes side by side on the fitted positions: focal-sample strains ('est'),
# every fitted strain ('other'), genotypes estimated directly from the focal
# metagenomes ('mgen') and from droplet pseudo-samples ('drplt').
d = sf.data.Genotypes.concat(
    {
        'est': e.genotypes.mlift('sel', strain=strain_list1),
        'other': e.genotypes,
        'mgen': (
            mgen_filt
            .mlift('sel', sample=sample_list)
            .mlift('sel', position=e.position)
            .to_estimated_genotypes()
        ),
        'drplt': drplt_filt_derep.mlift('sel', position=e.position).to_estimated_genotypes(),
    },
    dim='strain',
)

sf.plot.plot_genotype(
    d.to_world(),
    row_colors_func=None,
    col_colors_func=None,
    scalex=1e-3,
    transpose=True,
)

In [None]:
e = est1
sample_list = [smp for smp in e.sample.values if smp in ['SS01057.m', 'SS01009.m']]

# Labels: 'est' focal-sample strains, 'othere' all fitted strains, 'mgen' focal
# metagenomes, 'otherm' all metagenomes, 'drplt' droplets (coverage filter 0.05);
# everything restricted to the fitted positions.
d = sf.data.Genotypes.concat(
    {
        'est': e.genotypes.mlift('sel', strain=strain_list1),
        'othere': e.genotypes,
        'mgen': (
            mgen_filt
            .mlift('sel', sample=sample_list)
            .mlift('sel', position=e.position)
            .to_estimated_genotypes()
        ),
        'otherm': mgen_filt.mlift('sel', position=e.position).to_estimated_genotypes(),
        'drplt': (
            drplt_filt_derep
            .select_samples_with_coverage(0.05)
            .mlift('sel', position=e.position)
            .to_estimated_genotypes()
        ),
    },
    dim='strain',
)

def _assign_gtype(x):
    """Return, per index label of ``x``, the prefix before its first '_' (as a Series)."""
    return x.index.to_series().str.split('_', n=1).str[0]

fig, ax = plt.subplots(figsize=(5, 5))

# gtype = label prefix; z lifts 'est'/'mgen' points above the others.
_ordin_meta = (
    pd.DataFrame([], index=d.strain)
    .assign(gtype=_assign_gtype)
    .assign(z=lambda frame: frame.gtype.isin(['est', 'mgen']).astype(int))
)
_, ordin, _ = ordination_plot(
    d.pdist(),
    meta=_ordin_meta,
    ordin=nmds_ordination,
    ordin_kws=dict(is_dmat=True),
    colorby='gtype',
    zorderby='z',
    scatter_kws=dict(lw=0),
    ax=ax,
)

# Name the focal-sample strains and focal metagenomes on the plot.
for name, d1 in ordin[ordin.gtype.isin(['est', 'mgen'])].iterrows():
    ax.annotate(name, xy=d1[['PC1', 'PC2']].to_list())
None

In [None]:
e = est1
sample_list = [smp for smp in e.sample.values if smp in ['SS01057.m', 'SS01009.m']]

# Reduced collection: focal-sample strains, focal metagenomes and droplets
# (coverage filter 0.05), all restricted to the fitted positions.
d = sf.data.Genotypes.concat(
    {
        'est': e.genotypes.mlift('sel', strain=strain_list1),
        'mgen': (
            mgen_filt
            .mlift('sel', sample=sample_list)
            .mlift('sel', position=e.position)
            .to_estimated_genotypes()
        ),
        'drplt': (
            drplt_filt_derep
            .select_samples_with_coverage(0.05)
            .mlift('sel', position=e.position)
            .to_estimated_genotypes()
        ),
    },
    dim='strain',
)

def _assign_gtype(x):
    """Return, per index label of ``x``, the prefix before its first '_' (as a Series)."""
    return x.index.to_series().str.split('_', n=1).str[0]

fig, ax = plt.subplots(figsize=(5, 5))

# gtype = label prefix; z lifts 'est'/'mgen' points above the droplets.
_ordin_meta = (
    pd.DataFrame([], index=d.strain)
    .assign(gtype=_assign_gtype)
    .assign(z=lambda frame: frame.gtype.isin(['est', 'mgen']).astype(int))
)
_, ordin, _ = ordination_plot(
    d.pdist(),
    meta=_ordin_meta,
    ordin=nmds_ordination,
    ordin_kws=dict(is_dmat=True),
    colorby='gtype',
    zorderby='z',
    scatter_kws=dict(lw=0),
    ax=ax,
)

# Name the focal-sample strains and focal metagenomes on the plot.
for name, d1 in ordin[ordin.gtype.isin(['est', 'mgen'])].iterrows():
    ax.annotate(name, xy=d1[['PC1', 'PC2']].to_list())
None

In [None]:
e = est1
sample_list = [smp for smp in e.sample.values if smp in ['SS01057.m', 'SS01009.m']]

# Rows: every fitted strain. Columns: droplet pseudo-samples passing the 0.05
# coverage filter, restricted to the fitted positions.
d1 = sf.data.Genotypes.concat({'other': e.genotypes}, dim='strain')
d2 = sf.data.Genotypes.concat(
    {
        'drplt': (
            drplt_filt_derep
            .select_samples_with_coverage(0.05)
            .mlift('sel', position=e.position)
            .to_estimated_genotypes()
        ),
    },
    dim='strain',
)

diss = pd.DataFrame(
    sf.math.genotype_cdist(d1.values, d2.values),
    index=d1.strain,
    columns=d2.strain,
)

# Color scale saturates at 0.05.
sns.clustermap(diss, vmax=0.05, row_linkage=d1.linkage(), col_linkage=d2.linkage())

In [None]:
e = est1
sample_list = [smp for smp in e.sample.values if smp in ['SS01057.m', 'SS01009.m']]

# Fitted strains ('haplo') alongside droplet pseudo-sample genotypes ('drplt',
# coverage filter 0.05) on the fitted positions.
d = sf.data.Genotypes.concat(
    {
        'haplo': e.genotypes,
        'drplt': (
            drplt_filt_derep
            .mlift('sel', position=e.position)
            .select_samples_with_coverage(0.05)
            .to_estimated_genotypes()
        ),
    },
    dim='strain',
)


def _haplo_row_colors(w):
    """Row code: 0 = droplet, 1 = fitted strain, 2 = fitted strain in strain_list1."""
    is_haplo = w.strain.str.startswith('haplo').astype(int)
    is_focal = w.strain.isin([f'haplo_{i}' for i in strain_list1]).astype(int)
    return xr.Dataset(dict(haplo=is_haplo + is_focal))


sf.plot.plot_genotype(
    d.to_world(),
    row_colors_func=_haplo_row_colors,
    col_colors_func=None,
    scalex=1e-3,
    dwidth=5.0,
    dheight=0.1,
    cwidth=1.,
    transpose=True,
)

In [None]:
nstrain = 50
nposition = min(int(3e4), mgen_filt.sizes['position'])

# Joint input: metagenomes plus deduplicated droplet pseudo-samples.
d = (
    sf.data.Metagenotypes.concat(
        dict(mgen=mgen_filt, drplt=drplt_filt_derep),
        dim='sample',
    )
    .select_samples_with_coverage(0.05)
)

# Use the GPU when present; otherwise fall back to CPU instead of failing with a
# CUDA initialization error.
_fit_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Same workflow and hyperparameters as the metagenome-only fit, on the joint input.
est2 = (
    sf.workflow.fit_subsampled_metagenotype_collapse_strains_then_iteratively_refit_full_genotypes_no_missing(
        full_metagenotype_no_missing.model_structure,
        d.random_sample(position=nposition),
        nstrain=nstrain,
        hyperparameters=dict(
            gamma_hyper=0.1,
            rho_hyper=1.0,
            pi_hyper=1.0,
            epsilon_hyper_mode=0.01,
            epsilon_hyper_spread=1.5,
            alpha_hyper_hyper_mean=1000.0,
            alpha_hyper_hyper_scale=0.5,
            alpha_hyper_scale=2.0,
        ),
        nposition=nposition,
        diss_thresh=0.05,
        frac_thresh=1e-3,
        stage2_hyperparameters=dict(gamma_hyper=1.0),
        device=_fit_device,
        dtype=torch.float64,
        quiet=False,
        estimation_kwargs=dict(
            lagA=10,
            lagB=100,
            opt=pyro.optim.Adamax({"lr": 1.0}, {"clip_norm": 100}),
            seed=2,
            jit=True,
        ),
    )
)

In [None]:
e = est2


def _joint_fit_col_annotations(w):
    """Per-sample annotations for the joint (metagenome + droplet) fit."""
    return xr.Dataset(dict(
        focal=w.sample.isin(['mgen_SS01009.m', 'mgen_SS01057.m']),
        drplt=w.sample.str.startswith('drplt_'),
        depth=w.metagenotypes.mean_depth().pipe(np.sqrt),
        m_ent=w.metagenotypes.entropy('sample'),
        c_ent=w.communities.entropy('sample'),
        alpha=w.data.alpha.pipe(np.log),
        epsilon=w.data.epsilon,
        mu=w.data.mu.pipe(np.sqrt),
    ))


# Strain x sample composition of the joint fit (square-root color scale).
g = sf.plot.plot_community(
    e,
    row_linkage_func=lambda wld: wld.genotypes.linkage(dim="strain"),
    col_colors_func=_joint_fit_col_annotations,
    norm=mpl.colors.PowerNorm(1/2),
    pad_height=2.,
)

In [None]:
e = est2

# Joint-fit genotypes. Row annotations: log of the sample-averaged
# (strain fraction x sample mean depth), and genotype entropy.
sf.plot.plot_genotype(
    e,
    row_colors_func=lambda wld: xr.Dataset({
        'depth': (wld.communities.data * wld.metagenotypes.mean_depth()).mean("sample").pipe(np.log),
        'entropy': wld.genotypes.entropy(),
    }),
    scalex=1e-3,
)

In [None]:
e = est2
sample_list = [
    smp for smp in e.sample.values
    if smp in ['mgen_SS01057.m', 'mgen_SS01009.m'] or smp.startswith('drplt_')
]

# Joint-fit composition of the two focal metagenomes and all droplet pseudo-samples.
d = e.sel(sample=sample_list)

sf.plot.plot_community(
    d,
    col_colors_func=lambda wld: xr.Dataset(dict(
        focal=wld.sample.isin(['mgen_SS01009.m', 'mgen_SS01057.m']),
    )),
)

In [None]:
e = est2
sample_list = [
    smp for smp in e.sample.values
    if smp in ['mgen_SS01057.m', 'mgen_SS01009.m'] or smp.startswith('drplt_')
]

# Input metagenotypes for the same focal metagenomes and droplet pseudo-samples.
d = (
    sf.data.Metagenotypes.concat(
        dict(mgen=mgen_filt, drplt=drplt_filt_derep),
        dim='sample',
    )
    .select_samples_with_coverage(0.05)
    .mlift('sel', sample=sample_list)
)

sf.plot.plot_metagenotype(
    d.to_world(),
    col_colors_func=lambda wld: xr.Dataset(dict(
        focal=wld.sample.isin(['mgen_SS01009.m', 'mgen_SS01057.m']),
    )),
)