# Preamble

In [None]:
# Load IPython's autoreload extension. Mode 0 disables automatic reloading;
# switch to `%autoreload 2` while editing the local `lib`/`sfacts` modules.
%load_ext autoreload
%autoreload 0

In [None]:
import sys

# Local StrainFacts checkout, presumably what provides the `sfacts` package
# imported below. NOTE(review): absolute, machine-specific path; adjust when
# running elsewhere.
STRAINFACTS_PATH = '/pollard/data/projects/bsmith/sc-validate-haplotypes/include/StrainFacts'
# Guard so that re-running this cell does not append duplicate sys.path entries.
if STRAINFACTS_PATH not in sys.path:
    sys.path.append(STRAINFACTS_PATH)

## Imports

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from lib.pandas_util import idxwhere
import warnings
import torch
import pyro
import scipy as sp

import sfacts as sf
from sfacts.model_zoo import (
    full_metagenotype,
    full_metagenotype_no_missing,
    full_metagenotype_pp_pi,
    full_metagenotype_dirichlet_rho,
    full_metagenotype_dirichlet_rho_no_missing,
    full_metagenotype_pp_pi_no_missing,
#     full_metagenotype_special_meta,
#     full_metagenotype_no_missing_special_meta,
)


from lib.project_style import color_palette, major_allele_frequency_bins
from lib.project_data import metagenotype_db_to_xarray

In [None]:
warnings.filterwarnings(
    "ignore",
    message="torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.",
    category=torch.jit.TracerWarning,
#     module="trace_elbo",  # FIXME: What is the correct regex for module?
#     lineno=5,
)

## Style

## Load Data

In [None]:
# Open the project SQLite database (the path is relative to the notebook's
# working directory). The `mode=rw` URI opens it read-write but, unlike a plain
# `sqlite3.connect(path)`, raises immediately if the file does not exist
# instead of silently creating a new, empty database.
con = sqlite3.connect('file:data/core.a.proc.gtpro.2.denorm.db?mode=rw', uri=True)

In [None]:
# Species whose SNP tallies are loaded; a string because the queries compare it
# against the `species_id` column.
species_of_interest = "100199"

In [None]:
# Reference/alternative allele tallies per (library, position) for every
# metagenome library of the focal species. The species ID is bound as a query
# parameter (sqlite3 qmark style) instead of being interpolated into the SQL.
_mgen_all = pd.read_sql(
    """
SELECT
  lib_id
, species_position
, reference_tally
, alternative_tally
FROM snp_x_lib
JOIN lib USING (lib_id)
WHERE species_id = ?
  AND lib_type = 'metagenome'
""",
    con=con,
    index_col=['lib_id', 'species_position'],
    params=(species_of_interest,),
)

mgen_all = sf.data.Metagenotypes(metagenotype_db_to_xarray(_mgen_all))

In [None]:
# Two-step filter: select variable positions, then samples with enough
# coverage. NOTE(review): the exact meaning of both 0.05 thresholds is defined
# by sfacts.data.Metagenotypes; confirm there.
_mgen_variable = mgen_all.select_variable_positions(thresh=0.05)
mgen_filt = _mgen_variable.select_samples_with_coverage(0.05)

# 1000-position random subset (no seed is passed here).
mgen_ss = mgen_filt.random_sample(position=1000)

In [None]:
# Show the per-dimension sizes of the filtered metagenotypes (a mapping keyed by
# dimension name; its 'position' entry later caps `nposition` for the fit).
mgen_filt.sizes

In [None]:
# Reference/alternative allele tallies per (library, position) for every
# droplet library of the focal species. The species ID is bound as a query
# parameter (sqlite3 qmark style) instead of being interpolated into the SQL.
_drplt_all = pd.read_sql(
    """
SELECT
  lib_id
, species_position
, reference_tally
, alternative_tally
FROM snp_x_lib
JOIN lib USING (lib_id)
WHERE species_id = ?
  AND lib_type = 'droplet'
""",
    con=con,
    index_col=['lib_id', 'species_position'],
    params=(species_of_interest,),
)

drplt_all = sf.data.Metagenotypes(metagenotype_db_to_xarray(_drplt_all))

In [None]:
# Align droplet tallies to the filtered metagenome positions (`mlift` presumably
# applies xarray's `reindex`; missing positions are filled with 0), then drop
# droplets below the coverage threshold.
drplt_filt = (
    drplt_all
    .mlift('reindex', position=mgen_filt.position, fill_value=0)
    .select_samples_with_coverage(0.05)
)

In [None]:
def _combined_sample_colors(w):
    """Column annotations for the combined metagenome + droplet view.

    Returns an xr.Dataset with a boolean `focal` flag for the two focal
    metagenomes, a boolean `drplt` flag for droplet libraries, the log of mean
    depth, and the metagenotype entropy per sample.
    """
    return xr.Dataset(dict(
        # NOTE(review): the 'mgen_' prefix assumes concat prefixes sample names
        # with the dict key used below.
        focal=w.sample.isin(['mgen_SS01009.m', 'mgen_SS01057.m']),
        drplt=w.sample.str.startswith('drplt'),
        depth=w.metagenotypes.mean_depth().pipe(np.log),
        mgen_entropy=w.metagenotypes.entropy('sample'),
    ))


# Stack filtered metagenomes and droplets along `sample`, keep samples with
# coverage above 0.1, and subsample 1000 positions for plotting.
d = (
    sf.data.Metagenotypes.concat(
        dict(
            mgen=mgen_filt,
            drplt=drplt_filt,
        ),
        dim='sample',
    )
    .select_samples_with_coverage(0.1)
    .random_sample(position=1000)
)

# Both plots share one annotation function so the tracks cannot drift apart.
sf.plot.plot_depth(d.to_world(), col_colors_func=_combined_sample_colors)

sf.plot.plot_metagenotype(d.to_world(), col_colors_func=_combined_sample_colors)

In [None]:
# Fit StrainFacts. Judging by the workflow name, it fits a positional subsample
# with `nstrain` candidate strains, collapses near-duplicate strains
# (`diss_thresh`, `frac_thresh`), then iteratively refits full genotypes with
# `stage2_hyperparameters`; confirm against sfacts.workflow.
nstrain = 200
nposition = min(int(5e3), mgen_filt.sizes['position'])
# Use the GPU when one is available, but fall back to CPU instead of failing
# on machines without CUDA.
fit_device = 'cuda' if torch.cuda.is_available() else 'cpu'

est1 = sf.workflow.fit_subsampled_metagenotype_collapse_strains_then_iteratively_refit_full_genotypes(
    full_metagenotype_no_missing.model_structure,
    mgen_filt.random_sample(position=nposition),
    nstrain=nstrain,
    hyperparameters=dict(
        gamma_hyper=0.1,
        rho_hyper=0.5,
        pi_hyper=1.0,
        epsilon_hyper_mode=0.01,
        epsilon_hyper_spread=1.5,
        alpha_hyper_hyper_mean=1000.0,
        alpha_hyper_hyper_scale=0.5,
        alpha_hyper_scale=2.0,
        # Currently disabled hyperparameters, kept for reference:
        # delta_hyper_r=0.8, delta_hyper_temp=0.01,
        # m_hyper_hyper_r_mean=10., m_hyper_hyper_r_scale=1., m_hyper_r_scale=1.,
    ),
    nposition=nposition,
    diss_thresh=0.05,
    frac_thresh=1e-3,
    stage2_hyperparameters=dict(gamma_hyper=1.0),
    device=fit_device,
    dtype=torch.float64,
    quiet=False,
    estimation_kwargs=dict(
        lagA=10,
        lagB=100,
        opt=pyro.optim.Adamax({"lr": 1.0}, {"clip_norm": 100}),
        seed=2,
        jit=True,
    ),
)

In [None]:
e = est1


def _strain_row_colors(w):
    """Row annotations: log community-weighted mean depth and genotype entropy."""
    weighted_depth = w.communities.data * w.metagenotypes.mean_depth()
    return xr.Dataset(dict(
        depth=weighted_depth.mean("sample").pipe(np.log),
        entropy=w.genotypes.entropy(),
    ))


# Genotypes clustered by strain (rows) and by metagenotype position (columns).
sf.plot.plot_genotype(
    e,
    row_linkage_func=lambda w: w.genotypes.linkage(dim="strain"),
    col_linkage_func=lambda w: w.metagenotypes.linkage(dim="position"),
    row_colors_func=_strain_row_colors,
)

In [None]:
e = est1


def _community_sample_colors(w):
    """Column annotations: focal flag, log mean depth, and two per-sample entropies."""
    # est1 was fit on mgen_filt alone, so the focal IDs here carry no 'mgen_'
    # prefix (unlike the combined mgen + droplet view).
    return xr.Dataset(dict(
        focal=w.sample.isin(['SS01009.m', 'SS01057.m']),
        depth=w.metagenotypes.mean_depth().pipe(np.log),
        mgen_entropy=w.metagenotypes.entropy('sample'),
        comm_entropy=w.communities.entropy('sample'),
    ))


# Strain rows are clustered by genotype; sample columns are not given a linkage
# function. Colour uses a square-root (gamma = 1/2) power norm.
sf.plot.plot_community(
    e,
    row_linkage_func=lambda w: w.genotypes.linkage(dim="strain"),
    col_colors_func=_community_sample_colors,
    norm=mpl.colors.PowerNorm(1/2),
)

In [None]:
e = est1

# Community values of every strain in the two focal samples, as a Series
# indexed by (sample, strain).
d = e.communities.sel(sample=['SS01009.m', 'SS01057.m']).to_series()
# Strains exceeding 0.01 in at least one focal sample.
_above_threshold = d > 0.01
strain_list1 = d[_above_threshold].reset_index()['strain'].unique()
print(strain_list1)
d.sort_values(ascending=False).head(10)

In [None]:
e = est1

# Restrict the estimated genotypes to the strains selected in strain_list1.
_focal_genotypes = e.genotypes.mlift('sel', strain=strain_list1)
sf.plot.plot_genotype(_focal_genotypes.to_world())

In [None]:
e = est1

# Stack four genotype sources along `strain`, all restricted to est1's positions.
# NOTE(review): `other` is the full est1 genotype set, so the `est` strains also
# appear there a second time.
_est_genotypes = e.genotypes.mlift('sel', strain=strain_list1)
_all_genotypes = e.genotypes
_mgen_genotypes = (
    mgen_filt
    .mlift('sel', sample=['SS01009.m', 'SS01057.m'])
    .mlift('sel', position=e.position)
    .to_estimated_genotypes()
)
_drplt_genotypes = drplt_filt.mlift('sel', position=e.position).to_estimated_genotypes()
d = sf.data.Genotypes.concat(
    dict(est=_est_genotypes, other=_all_genotypes, mgen=_mgen_genotypes, drplt=_drplt_genotypes),
    dim='strain',
)

# No colour tracks; linkage/clustering left at plot_genotype's defaults.
sf.plot.plot_genotype(d.to_world(), row_colors_func=None, col_colors_func=None)

In [None]:
from lib.plot import ordination_plot, mds_ordination, nmds_ordination

e = est1

# Same four-way genotype stack as the heatmap cell, rebuilt here so this cell
# does not depend on `d` from a previous cell.
_focal_sample_ids = ['SS01009.m', 'SS01057.m']
_genotype_groups = dict(
    est=e.genotypes.mlift('sel', strain=strain_list1),
    other=e.genotypes,
    mgen=mgen_filt.mlift('sel', sample=_focal_sample_ids).mlift('sel', position=e.position).to_estimated_genotypes(),
    drplt=drplt_filt.mlift('sel', position=e.position).to_estimated_genotypes(),
)
d = sf.data.Genotypes.concat(_genotype_groups, dim='strain')

def _assign_gtype(x):
    return x.index.to_series().str.split('_').apply(lambda x: x[0])

fig, ax = plt.subplots(figsize=(10, 10))

# Pairwise genotype distances, ordinated with NMDS on the precomputed matrix.
_genotype_dmat = d.pdist()
# Per-strain metadata: group label plus a z-order key that draws `est` on top.
_strain_meta = (
    pd.DataFrame([], index=d.strain)
    .assign(gtype=_assign_gtype)
    .assign(z=lambda df: (df.gtype == 'est').astype(int))
)

_, ordin, _ = ordination_plot(
    _genotype_dmat,
    meta=_strain_meta,
    ordin=nmds_ordination,
    ordin_kws=dict(is_dmat=True),
    colorby='gtype',
    zorderby='z',
    scatter_kws=dict(lw=0, alpha=0.7),
    ax=ax,
)

# Label only the focal-strain and focal-metagenome points.
_labelled_points = ordin[ordin.gtype.isin(['est', 'mgen'])]
for label, coords in _labelled_points.iterrows():
    ax.annotate(label, xy=coords[['PC1', 'PC2']].to_list())