In [None]:
# Load IPython's autoreload extension. Mode 0 turns automatic reloading off,
# so edited modules are only re-imported when `%autoreload` is invoked explicitly.
%load_ext autoreload
%autoreload 0

In [None]:
# Manual, one-off reload of already-imported modules (run this cell after
# editing local library code; automatic reloading is disabled above).
%autoreload

In [None]:
import sys
# Put a local StrainFacts checkout on the import path (presumably the source
# of the `sfacts` package imported below — TODO confirm).
# NOTE(review): this is an absolute, user-specific path; the notebook will not
# import on another machine unless the same directory exists.
sys.path.append('/pollard/home/bsmith/Projects/haplo-benchmark/include/StrainFacts')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from lib.pandas_util import idxwhere
import sfacts as sf
import numpy as np
import xarray as xr
import seaborn as sns
from scipy.spatial.distance import cdist
import lib.plot

In [None]:
# --- Analysis parameters ----------------------------------------------------
focal_species = '104345'
threshold = 0.01        # compared against community values when picking focal strains
pseudo = 1e-10          # pseudo-count handed to `to_estimated_genotypes`
scg_cvrg_thresh = 0.05  # coverage cutoff handed to `select_samples_with_coverage`

# --- Input files ------------------------------------------------------------
# mgen_path = f'data/ucfmt.sp-{focal_species}.metagenotype.filt-poly05-cvrg05.nc'
scg_path = f'data/ucfmt.sp-{focal_species}.derep.genotype.nc'
fit_path = f'data/ucfmt.sp-{focal_species}.metagenotype.filt-poly05-cvrg05.fit-sfacts42-s30-g5000-seed0.refit-sfacts41-g10000-seed0.world.nc'
scg_to_sample_path = f'data/ucfmt.sp-{focal_species}.derep.barcode_to_sample.tsv'
library_to_sample_path = 'data/ucfmt.barcode_to_sample.tsv'

# --- Barcode -> sample lookup tables ----------------------------------------
# The SCG table has no header row, so column names are supplied explicitly.
scg_to_sample = pd.read_table(
    scg_to_sample_path,
    names=['scg', 'sample_id'],
    index_col='scg',
)
# The library table's `barcode` column is renamed to `mgen` and used as index.
library_to_sample = (
    pd.read_table(library_to_sample_path)
    .rename(columns={'barcode': 'mgen'})
    .set_index('mgen')
)

# --- Genotype data and fitted model -----------------------------------------
# mgen = sf.data.Metagenotypes.load(mgen_path)
drplt = sf.data.Metagenotypes.load(scg_path)
inference = sf.data.World.load(fit_path)
# Attach two derived variables to the loaded fit: the product of the
# communities and genotypes arrays, and the metagenotype total counts.
inference.data['p'] = inference.data['communities'] @ inference.data['genotypes']
inference.data['m'] = inference.metagenotypes.total_counts()

inferred_community = inference.communities

# Positions present in both the fit and the SCG data; every genotype set
# below is restricted to exactly this list so they are directly comparable.
shared_position = list(set(inference.position.values) & set(drplt.position.values))


def _at_shared_positions(genotypes):
    """Select only the positions listed in `shared_position`."""
    return genotypes.mlift('sel', position=shared_position)


consensus = _at_shared_positions(
    inference.metagenotypes.to_estimated_genotypes(pseudo=pseudo)
)
scg = _at_shared_positions(
    drplt
    .select_samples_with_coverage(scg_cvrg_thresh)
    .to_estimated_genotypes(pseudo=pseudo)
)
inferred_genotype = _at_shared_positions(inference.genotypes)

In [None]:
# Show the `sizes` attribute of the loaded fit (presumably its dimension
# sizes — TODO confirm against the sfacts World API).
inference.sizes

In [None]:
# Evaluate the fit's metagenotype error twice: `a` without and `b` with
# discretization. Each result is indexed below as [0] (a single summary
# value) and [1] (a per-sample lookup keyed by sample name).
a, b = (
    sf.evaluation.metagenotype_error2(inference, discretized=use_discrete)
    for use_discrete in (False, True)
)

# Display the overall values next to those of one sample of interest.
# (Previously inspected: 'SS01009.m'.)
a[0], b[0], a[1]['SS01057.m'], b[1]['SS01057.m']

In [None]:
focal_sample = 'SS01057'

# Barcodes mapped to the focal sample in each lookup table.
_focal_scg_ids = idxwhere(scg_to_sample.sample_id == focal_sample)
_focal_library_ids = idxwhere(library_to_sample.sample_id == focal_sample)

# Restrict the SCG genotypes, the metagenotype consensus genotypes and the
# inferred communities to entries belonging to the focal sample.
focal_scg = scg.mlift('sel', strain=scg.strain.isin(_focal_scg_ids))
focal_mgen = consensus.mlift(
    'sel', strain=consensus.strain.isin(_focal_library_ids)
)
focal_comm = inferred_community.mlift(
    'sel', sample=inferred_community.sample.isin(_focal_library_ids)
)

# Strains whose community value exceeds `threshold` in at least one of the
# focal sample's libraries, and the inferred genotypes of those strains.
_above_threshold = (focal_comm.data > threshold).any("sample")
focal_strains = idxwhere(_above_threshold.to_series())
focal_geno = inferred_genotype.mlift('sel', strain=focal_strains)

In [None]:
def _scg_match_dist(target, q=None):
    """Second element of `sf.match_genotypes` for focal SCGs against `target`.

    With `q=None` the call uses `sf.match_genotypes`' default distance;
    otherwise `sf.math.genotype_cdist(..., q=q)` is passed as `cdist`.
    """
    extra = {}
    if q is not None:
        extra['cdist'] = lambda x, y: sf.math.genotype_cdist(x, y, q=q)
    return sf.match_genotypes(focal_scg.to_world(), target.to_world(), **extra)[1]


# Suffix key: f = default distance, d = default distance on discretized
# genotypes, a = q=1, b = q=3, e = q=4.
# `mgen` = metagenotype consensus, `strain` = inferred strain genotypes.
scg_to_focal_mgen_fdist = _scg_match_dist(focal_mgen)
scg_to_focal_strain_fdist = _scg_match_dist(focal_geno)
scg_to_focal_mgen_ddist = _scg_match_dist(focal_mgen.discretized())
scg_to_focal_strain_ddist = _scg_match_dist(focal_geno.discretized())
scg_to_focal_mgen_adist = _scg_match_dist(focal_mgen, q=1)
scg_to_focal_strain_adist = _scg_match_dist(focal_geno, q=1)
scg_to_focal_mgen_bdist = _scg_match_dist(focal_mgen, q=3)
scg_to_focal_strain_bdist = _scg_match_dist(focal_geno, q=3)
scg_to_focal_mgen_edist = _scg_match_dist(focal_mgen, q=4)
scg_to_focal_strain_edist = _scg_match_dist(focal_geno, q=4)

In [None]:
# For each distance variant, scatter the SCG-to-metagenotype-consensus match
# distance (x) against the SCG-to-inferred-strain match distance (y).
plt.scatter(scg_to_focal_mgen_ddist, scg_to_focal_strain_ddist, label='disc')
# The q=1 distances are stored under the `*_adist` names. This line used to
# read `*_cdist`, which is never assigned anywhere and raised a NameError on
# a fresh kernel.
plt.scatter(scg_to_focal_mgen_adist, scg_to_focal_strain_adist, label='1')
# `*_fdist` uses the default distance of `sf.match_genotypes`
# (presumably q=2, hence the label — TODO confirm).
plt.scatter(scg_to_focal_mgen_fdist, scg_to_focal_strain_fdist, label='2')
plt.scatter(scg_to_focal_mgen_bdist, scg_to_focal_strain_bdist, label='3')
plt.scatter(scg_to_focal_mgen_edist, scg_to_focal_strain_edist, label='4')

# plt.yscale('log')
# plt.xscale('log')
plt.legend(bbox_to_anchor=(1, 1))

# Identity line for reference: points below it matched an inferred strain
# more closely than they matched the metagenotype consensus.
plt.plot([0, 0.3], [0, 0.3])

In [None]:
# Histogram of SCG match distances: blue = against the metagenotype
# consensus, red = against the inferred strains. Filled = default distance,
# outlined = distance on discretized genotypes.
top = 0.3
step = 0.005
# `bins` holds bin EDGES, so N bins of width `step` need N + 1 edges. The
# earlier version passed N edges (giving N - 1 slightly wider bins) and
# hard-coded the upper bound instead of reusing `top`.
n_bins = int(np.ceil(top / step))
bins = np.linspace(0, top, num=n_bins + 1)

plt.hist(scg_to_focal_mgen_fdist, bins=bins, alpha=0.4, color='blue', histtype='stepfilled')
plt.hist(scg_to_focal_mgen_ddist, bins=bins, alpha=0.8, color='blue', histtype='step', linestyle='-', lw=2)
# q=1 variant (stored as `*_adist`; the `*_cdist` name is never defined):
# plt.hist(scg_to_focal_mgen_adist, bins=bins, alpha=0.8, color='blue', histtype='step', linestyle='--', lw=2)

plt.hist(scg_to_focal_strain_fdist, bins=bins, alpha=0.4, color='red', histtype='stepfilled')
plt.hist(scg_to_focal_strain_ddist, bins=bins, alpha=0.8, color='red', histtype='step', linestyle='-', lw=2)
# plt.hist(scg_to_focal_strain_adist, bins=bins, alpha=0.8, color='red', histtype='step', linestyle='--', lw=2)

# plt.hist(scg_to_focal_strain_fdist, bins=bins, alpha=0.6)
# Trailing `None` suppresses the repr of the last `plt.hist` return value.
None

In [None]:
# Stack the three focal genotype sets along the strain dimension, keyed by
# source: g = SCG genotypes, m = metagenotype consensus, s = inferred strains.
_genotype_sources = {
    'g': focal_scg,
    'm': focal_mgen,
    's': focal_geno,
}
together = sf.data.Genotypes.concat(_genotype_sources, dim='strain')

# Plot the combined genotypes, transposed, with a small horizontal scale.
sf.plot_genotype(together, transpose=True, scalex=1e-3)

In [None]:
def _sample_linkage(w):
    """Linkage over samples, taken from the world's communities."""
    return w.communities.linkage("sample")


def _strain_linkage(w):
    """Linkage over strains, taken from the world's genotypes."""
    return w.genotypes.linkage("strain")


def _community_col_colors(w):
    """Per-sample annotations: focal-sample flag plus both entropies."""
    return xr.Dataset(dict(
        focal=w.sample.str.startswith(f'{focal_sample}'),
        m_entropy=w.metagenotypes.entropy(),
        c_entropy=w.communities.entropy(),
    ))


def _metagenotype_col_colors(w):
    """Per-sample annotations: focal-sample flag and metagenotype entropy."""
    return xr.Dataset(dict(
        focal=w.sample.str.startswith(f'{focal_sample}'),
        m_entropy=w.metagenotypes.entropy(),
#         alpha=w.data.alpha,
    ))


# Community plot for the full fit, with the annotations defined above.
sf.plot_community(
    inference,
    col_linkage_func=_sample_linkage,
    row_linkage_func=_strain_linkage,
    col_colors_func=_community_col_colors,
)
# Metagenotype plot on a random draw of 1000 positions.
sf.plot.plot_metagenotype2(
    inference.random_sample(position=1000),
    col_linkage_func=_sample_linkage,
    col_colors_func=_metagenotype_col_colors,
)

In [None]:
# Metagenotype entropy (x) against community entropy (y), as returned by the
# respective `entropy()` methods of the fit.
plt.scatter(
    x=inference.metagenotypes.entropy(),
    y=inference.communities.entropy(),
)

In [None]:
# Plot the genotypes on a random draw of 1000 positions. No seed is passed
# here, so the draw may differ between runs.
_position_subsample = inference.random_sample(position=1000)
sf.plot.plot_genotype(_position_subsample)

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))

# Pairwise distances between all stacked genotypes; passed to the ordination
# as a precomputed distance matrix (`is_dmat`).
_pairwise = together.pdist()

# Per-genotype metadata: `t` is the first character of the strain name and
# drives the point colour.
_strain_meta = pd.DataFrame(
    dict(t=together.strain.str[0]),
    index=together.strain,
).fillna(-1)

ax, ordin, *_ = lib.plot.ordination_plot(
    _pairwise,
    ordin=lib.plot.nmds_ordination,
    meta=_strain_meta,
    colorby='t',
#     color_palette=drplt_ucfmt_104345_strain_type_palette,
#     markerby='is_est',
#     marker_palette={True: '>', False: 'o'},
#     zorderby='is_est',
#     markersizeby='is_est',
#     markersize_palette={True: 60, False: 40},
    ordin_kws={'is_dmat': True},
#     fill_legend=False,
    scatter_kws=dict(lw=0.5, alpha=0.5),
    ax=ax,
)

# NOTE(review): the axes are labelled as principal coordinates although the
# ordination function is `nmds_ordination` — confirm the intended labels.
ax.set_xlabel('PCo1')
ax.set_ylabel('PCo2')

# Tag every ordination row with the part of its name before the first '_',
# then label only the 's' and 'm' points on the plot.
_name_parts = ordin.index.to_series().str.split('_')
ordin['gtype'] = _name_parts.apply(lambda x: x[0])
_labelled = ordin[ordin.gtype.isin(['s', 'm'])]
for name, d1 in _labelled.iterrows():
    ax.annotate(name, xy=d1[['PC1', 'PC2']].to_list())
None

In [None]:
fig, ax = plt.subplots()

# Frequency spectrum for a single sample, binned at 51 evenly spaced edges
# between 0.5 and 1.0, with `show_predict` enabled.
# NOTE(review): the sample is 'SS01009.m', not the focal sample selected
# earlier in the notebook — confirm this is intentional.
_spectrum_bins = np.linspace(0.5, 1.0, num=51)
sf.plot.plot_metagenotype_frequency_spectrum(
    inference,
    sample_list=['SS01009.m'],
    axs=ax,
    bins=_spectrum_bins,
    show_predict=True,
)
# ax.set_yscale('log')
ax.set_ylim(0, 400)