In [None]:
import os as _os

# Step up to the parent directory so the relative 'meta/...' and 'data/...'
# paths used below resolve against it.
# NOTE: not idempotent -- re-running this cell moves up one more level.
_os.chdir(_os.pardir)

In [None]:
# Load IPython's autoreload extension, then select mode 0, which turns
# automatic reloading off; modules are refreshed only on an explicit `%autoreload`.
%load_ext autoreload
%autoreload 0

In [None]:
# One-shot manual reload of already-imported modules (re-run after editing library code).
%autoreload

In [None]:
import sfacts as sf

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import scipy as sp
import matplotlib as mpl
import scipy as sp
from operator import eq
from itertools import cycle


In [None]:
def expected_sample_entropy(w, discretized=False):
    """Depth-weighted mean binary entropy of the fitted allele mixture.

    The binary entropy of `community @ genotype` is weighted by
    `w.metagenotype.total_counts()` and averaged over the "position" dimension.

    Parameters
    ----------
    w
        World-like object exposing `genotype`, `community` and `metagenotype`.
    discretized : bool
        When true, use `w.genotype.discretized()` instead of `w.genotype`.

    Returns
    -------
    The weighted average, renamed to "entropy".
    """
    genotype = w.genotype.discretized() if discretized else w.genotype
    gen = genotype.data
    com = w.community.data
    depth = w.metagenotype.total_counts()

    weighted_entropy = sf.math.binary_entropy(com @ gen) * depth
    mean_entropy = weighted_entropy.sum("position") / depth.sum("position")
    return mean_entropy.rename("entropy")

def max_strain_depth(w):
    """Largest value over 'sample' of community data times mean metagenotype depth.

    Multiplies `w.community.data` by `w.metagenotype.mean_depth()`, reduces
    with `max` along the 'sample' dimension and renames the result to 'depth'.
    """
    strain_depth = w.community.data * w.metagenotype.mean_depth()
    deepest = strain_depth.max('sample')
    return deepest.rename('depth')

def total_strain_depth(w):
    """Sum over 'sample' of community data times mean metagenotype depth.

    Multiplies `w.community.data` by `w.metagenotype.mean_depth()`, reduces
    with `sum` along the 'sample' dimension and renames the result to 'depth'.
    """
    strain_depth = w.community.data * w.metagenotype.mean_depth()
    summed = strain_depth.sum('sample')
    return summed.rename('depth')

In [None]:
# Metadata tables, each indexed by its own id column.
mgen = pd.read_table(
    'meta/mgen.tsv', index_col='library_id',
)
preparation = pd.read_table(
    'meta/preparation.tsv', index_col='preparation_id',
)
stool = pd.read_table(
    'meta/stool.tsv', index_col='stool_id',
)
visit = pd.read_table(
    'meta/visit.tsv', index_col='visit_id',
)
subject = pd.read_table(
    'meta/subject.tsv', index_col='subject_id',
)

# Follow the key chain library -> preparation -> stool -> visit -> subject,
# one join at a time. `library_type` is dropped from the preparation table
# before joining, and overlapping visit columns get a '_' suffix.
meta_all = mgen.join(preparation.drop(columns='library_type'), on='preparation_id')
meta_all = meta_all.join(stool, on='stool_id')
meta_all = meta_all.join(visit, on='visit_id', rsuffix='_')
meta_all = meta_all.join(subject, on='subject_id')

# Every library must resolve to a subject.
assert not meta_all.subject_id.isna().any()

# meta.columns

In [None]:
# Species identifier; interpolated into the data file paths built in the next cell.
species_id = '100022'

In [None]:
# File locations for this species: the filtered metagenotype and the fitted world.
metagenotype_stem = f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05'
world_path = f'{metagenotype_stem}.ss-g10000-block0-seed0.fit-sfacts11-s75-seed0.world.nc'

metagenotype = sf.Metagenotype.load(f'{metagenotype_stem}.mgen.nc')
world = sf.World.load(world_path)
print(world_path)

# Restrict the metadata to the samples in the fit, ordered by subject and then
# visit date, and impose that sample order on both data objects.
meta = meta_all.loc[world.sample].sort_values(['subject_id', 'visit_date'])
metagenotype = metagenotype.sel(sample=meta.index)
world = world.sel(sample=meta.index)
world_collapse = world.collapse_strains(thresh=0.05, discretized=True)

# Condensed (pdist-ordered) boolean vector: True where a sample pair shares a subject_id.
same_subject = sp.spatial.distance.pdist(
    meta[['subject_id']].values,
    metric=eq,
).astype(bool)

# Subsample to at most 1000 positions and keep only strains whose maximum
# community value over samples exceeds 0.01.
n_position_ss = min(world.sizes['position'], 1000)
_keep_strain = world.community.max("sample") > 0.01
_keep_strain_collapse = world_collapse.community.max("sample") > 0.01
w_ss = world.random_sample(position=n_position_ss).sel(strain=_keep_strain)
w_ssc = world_collapse.sel(position=w_ss.position).sel(strain=_keep_strain_collapse)

# Size summary: samples; positions (all vs. fit); strains (fit, subsample, collapsed subsample).
print(metagenotype.sizes['sample'])
print(metagenotype.sizes['position'], world.sizes['position'])
print(world.sizes['strain'], w_ss.sizes['strain'], w_ssc.sizes['strain'])

In [None]:
# Fit error from sfacts' `metagenotype_error2` with `discretized=True`; shown as the cell output.
sf.evaluation.metagenotype_error2(world, discretized=True)

In [None]:
# `world.data.rho` against the mean over samples of the community data.
# plt.scatter(world.data.rho, world.community.data.max('sample'))
plt.scatter(
    x=world.data.rho,
    y=world.community.data.mean('sample'),
)

In [None]:
# Pairwise sample distances on the subsampled world: UniFrac (discretized) and metagenotype.
_uf = w_ss.unifrac_pdist(discretized=True)
_mg = w_ss.metagenotype.pdist()

# One row per sample pair: x = metagenotype distance, y = UniFrac distance.
d = pd.DataFrame({
    'x': sp.spatial.distance.squareform(_mg),
    'y': sp.spatial.distance.squareform(_uf),
})
sns.jointplot(
    data=d,
    x='x',
    y='y',
    kind='hex',
    norm=mpl.colors.PowerNorm(1/2),
)
# Linear correlation between the two distances (cell output).
sp.stats.pearsonr(d['x'], d['y'])

In [None]:
# Integer colour codes for the sample annotations (`cycle` comes from the
# notebook's import cell, so it is not re-imported here).
# Subject codes wrap around every 20 values via `cycle(range(20))`.
# Site and diagnosis are numbered in order of first appearance with no upper
# bound, so categories past the 100th are no longer silently left unmapped.
subject_index_mod = dict(zip(meta.subject_id.unique(), cycle(range(20))))
diagnosis_map = {k: i for i, k in enumerate(meta.ibd_diagnosis.unique())}
site_map = {k: i for i, k in enumerate(meta.site.unique())}

# Community plot for the position-subsampled world, with samples kept in
# metadata order and annotated by site, diagnosis and subject.
sf.plot.plot_community(
    w_ss.sel(sample=meta.index),
    col_colors_func=lambda w: xr.Dataset(dict(
        site=meta.loc[w.sample].site.map(site_map),
        diagnosis=meta.loc[w.sample].ibd_diagnosis.map(diagnosis_map),
        subject=meta.loc[w.sample].subject_id.map(subject_index_mod),
    )),
    row_col_annotation_cmap=mpl.cm.tab20,
    row_linkage_func=lambda w: w.genotype.discretized().linkage(dim="strain"),
    scalex=0.05,
    xticklabels=0,
    col_cluster=False,
    # col_linkage_func=lambda w: w.unifrac_linkage(),
    # norm=mpl.colors.PowerNorm(1),
)

In [None]:
# `subject_index_mod`, `site_map` and `diagnosis_map` are reused from the
# community-plot cell above; the identical subject map is not rebuilt here.
# Metagenotype plot for the position-subsampled world, samples in metadata order.
sf.plot.plot_metagenotype(
    w_ss.sel(sample=meta.index),
    col_colors_func=lambda w: xr.Dataset(dict(
        site=meta.loc[w.sample].site.map(site_map),
        diagnosis=meta.loc[w.sample].ibd_diagnosis.map(diagnosis_map),
        subject=meta.loc[w.sample].subject_id.map(subject_index_mod),
    )),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
    col_cluster=False,
    # col_linkage_func=lambda w: w.unifrac_linkage(),
)

In [None]:
# Dominance plot for the position-subsampled world; rows follow a metagenotype
# linkage over positions, and samples are annotated by site, diagnosis and subject.
sf.plot.plot_dominance(
    w_ss.sel(sample=meta.index),
    row_linkage_func=lambda w: w.metagenotype.linkage(dim="position"),
    col_cluster=False,
    # col_linkage_func=lambda w: w.unifrac_linkage(),
    col_colors_func=lambda w: xr.Dataset({
        'site': meta.loc[w.sample]['site'].map(site_map),
        'diagnosis': meta.loc[w.sample]['ibd_diagnosis'].map(diagnosis_map),
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Genotype plot for the position-subsampled world. Strains are ordered by a
# linkage on discretized genotypes, positions by a metagenotype linkage.
sf.plot_genotype(
    w_ss,
    row_linkage_func=lambda w: w.genotype.discretized().linkage(dim="strain"),
    col_linkage_func=lambda w: w.metagenotype.linkage(dim="position"),
    # Single annotation track (keyed '_'): log of the strain depth summed over samples.
    row_colors_func=lambda w: xr.Dataset({
        '_': np.log(total_strain_depth(w)),
    }),
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Distribution of all fitted genotype values: 500 equal-width bins on [0, 1], log-scaled counts.
_genotype_values = world.genotype.to_series().values
plt.hist(
    _genotype_values,
    bins=np.linspace(0, 1, num=501),
)
plt.yscale('log')
# A trailing None keeps the cell from echoing a return value.
None

In [None]:
# Community plot for the collapsed, subsampled world. Samples are ordered by a
# metagenotype linkage, strains by a linkage on discretized genotypes.
sf.plot_community(
    w_ssc,
    row_linkage_func=lambda w: w.genotype.discretized().linkage(dim="strain"),
    col_linkage_func=lambda w: w.metagenotype.linkage(dim="sample"),
    col_colors_func=lambda w: xr.Dataset({
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Metagenotype plot for the collapsed, subsampled world, with samples ordered
# by a metagenotype linkage and annotated by subject.
sf.plot.plot_metagenotype(
    w_ssc,
    col_linkage_func=lambda w: w.metagenotype.linkage(dim="sample"),
    col_colors_func=lambda w: xr.Dataset({
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Dominance plot for the collapsed, subsampled world. The matrix is the
# transposed dominant-allele fraction computed with `pseudo=0`; both axes are
# ordered by metagenotype linkages.
sf.plot.plot_dominance(
    w_ssc,
    matrix_func=lambda w: w.metagenotype.dominant_allele_fraction(pseudo=0).T,
    row_linkage_func=lambda w: w.metagenotype.linkage(dim="position"),
    col_linkage_func=lambda w: w.metagenotype.linkage(dim="sample"),
    col_colors_func=lambda w: xr.Dataset({
        'site': meta.loc[w.sample]['site'].map(site_map),
        'diagnosis': meta.loc[w.sample]['ibd_diagnosis'].map(diagnosis_map),
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Depth plot for the position-subsampled world; both axes are ordered by
# metagenotype linkages and samples are annotated by site, diagnosis and subject.
sf.plot.plot_depth(
    w_ss,
    row_linkage_func=lambda w: w.metagenotype.linkage(dim="position"),
    col_linkage_func=lambda w: w.metagenotype.linkage(dim="sample"),
    col_colors_func=lambda w: xr.Dataset({
        'site': meta.loc[w.sample]['site'].map(site_map),
        'diagnosis': meta.loc[w.sample]['ibd_diagnosis'].map(diagnosis_map),
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Community plot for the collapsed, subsampled world, this time with samples
# ordered by `unifrac_linkage()`.
sf.plot_community(
    w_ssc,
    row_linkage_func=lambda w: w.genotype.discretized().linkage(dim="strain"),
    col_linkage_func=lambda w: w.unifrac_linkage(),
    col_colors_func=lambda w: xr.Dataset({
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Metagenotype plot for the collapsed, subsampled world, with samples ordered
# by `unifrac_linkage()` and annotated by subject.
sf.plot.plot_metagenotype(
    w_ssc,
    col_linkage_func=lambda w: w.unifrac_linkage(),
    col_colors_func=lambda w: xr.Dataset({
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Dominance plot for the collapsed, subsampled world. The matrix is the
# transposed dominant-allele fraction computed with `pseudo=0`; positions follow
# a metagenotype linkage and samples follow `unifrac_linkage()`.
sf.plot.plot_dominance(
    w_ssc,
    matrix_func=lambda w: w.metagenotype.dominant_allele_fraction(pseudo=0).T,
    row_linkage_func=lambda w: w.metagenotype.linkage(dim="position"),
    col_linkage_func=lambda w: w.unifrac_linkage(),
    col_colors_func=lambda w: xr.Dataset({
        'site': meta.loc[w.sample]['site'].map(site_map),
        'diagnosis': meta.loc[w.sample]['ibd_diagnosis'].map(diagnosis_map),
        'subject': meta.loc[w.sample]['subject_id'].map(subject_index_mod),
    }),
    row_col_annotation_cmap=mpl.cm.tab20,
    scalex=0.05,
    xticklabels=0,
)

In [None]:
# Pairwise sample distances, each passed through `squareform` so they line up
# with the `same_subject` vector. The `*c_dist` variants use the collapsed world.
_squareform = sp.spatial.distance.squareform

uf_dist = _squareform(sf.unifrac.unifrac_pdist(world))
ufc_dist = _squareform(sf.unifrac.unifrac_pdist(world_collapse))
bc_dist = _squareform(world.community.pdist(dim='sample'))
bcc_dist = _squareform(world_collapse.community.pdist(dim='sample'))
mg_dist = _squareform(world.metagenotype.pdist(dim='sample'))
mg_cos_dist = _squareform(world.metagenotype.cosine_pdist(dim='sample'))

# mg_all_dist = sp.spatial.distance.squareform(metagenotype.pdist(dim='sample'))
# mg_all_cos_dist = sp.spatial.distance.squareform(metagenotype.cosine_pdist(dim='sample'))


In [None]:
bins = np.linspace(0, 1, num=51)

all_dists = [uf_dist, bcc_dist, mg_dist, mg_cos_dist]

# One panel per distance measure, overlaying between-subject ("trans") and
# within-subject ("cis") sample pairs on shared axes.
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)

for ax, dist in zip(axs.flatten(), all_dists):
    # Draw only on the pre-allocated axes. Opening a new figure here would
    # emit a stray empty figure per iteration and rebind `fig` away from the grid.
    trans_dist = dist[~same_subject]
    cis_dist = dist[same_subject]
    ax.hist(trans_dist, bins=bins, density=True, alpha=0.5, label='trans')
    ax.hist(cis_dist, bins=bins, density=True, alpha=0.5, label='cis')
    ax.set_yscale('log')
    # U / (n_trans * n_cis) rescales the Mann-Whitney statistic to [0, 1], an
    # AUC-style separation score. This assumes the statistic is U for the
    # first sample, as in recent SciPy -- TODO confirm for the pinned version.
    mwu, p = sp.stats.mannwhitneyu(trans_dist, cis_dist)
    auc = mwu / (len(trans_dist) * len(cis_dist))
    print(auc)

axs[0, 1].legend()

In [None]:
# Metagenotype distance (x) against UniFrac distance (y), one small translucent marker per sample pair.
plt.scatter(
    x=mg_dist,
    y=uf_dist,
    s=0.1,
    alpha=0.2,
)
# sns.regplot(x=mg_dist, y=uf_dist, scatter=False, color='black', lowess=True)

In [None]:
# Observed metagenotype entropy (x) against the entropy expected from the
# discretized fit (y), coloured by mean depth on a log colour scale.
# `hcov` is carried in the frame but not drawn.
plt.scatter(
    x='mgen_entropy',
    y='expect_entropy',
    c='depth',
    s=3,
    norm=mpl.colors.LogNorm(),
    data=pd.DataFrame({
        'mgen_entropy': world.metagenotype.entropy(),
        'expect_entropy': expected_sample_entropy(world, discretized=True),
        'depth': world.metagenotype.mean_depth(),
        'hcov': world.metagenotype.horizontal_coverage(),
    }),
)
plt.colorbar()

# plt.yscale('log')
# plt.xscale('log')

In [None]:
# Observed vs. expected entropy, with mean depth and horizontal coverage alongside.
d = pd.DataFrame({
    'mgen_entropy': world.metagenotype.entropy(),
    'expect_entropy': expected_sample_entropy(world, discretized=True),
    'depth': world.metagenotype.mean_depth(),
    'hcov': world.metagenotype.horizontal_coverage(),
})

_abs_err = np.abs(d.mgen_entropy - d.expect_entropy)
_sqrt_expect = np.sqrt(d.expect_entropy)
_sqrt_mgen = np.sqrt(d.mgen_entropy)

# Depth-weighted mean absolute error, then Pearson and Spearman correlations
# of the square-root-transformed entropies.
print('WMAE:', (_abs_err * d.depth).sum() / d.depth.sum())
print('CORR:', sp.stats.pearsonr(_sqrt_expect, _sqrt_mgen)[0])
print('SPEAR:', sp.stats.spearmanr(_sqrt_expect, _sqrt_mgen)[0])

In [None]:
# For the subsampled world: total strain depth (x) against genotype entropy (y),
# coloured by the strain's maximum depth over samples.
plt.scatter(
    x='total_depth',
    y='geno_entropy',
    c='max_depth',
    data=pd.DataFrame({
        'geno_entropy': w_ss.genotype.entropy(),
        'max_depth': max_strain_depth(w_ss),
        'total_depth': total_strain_depth(w_ss),
    }),
)
plt.colorbar()

In [None]:
# For the collapsed, subsampled world: total strain depth (x) against genotype
# entropy (y), coloured by the strain's maximum depth over samples.
plt.scatter(
    x='total_depth',
    y='geno_entropy',
    c='max_depth',
    data=pd.DataFrame({
        'geno_entropy': w_ssc.genotype.entropy(),
        'max_depth': max_strain_depth(w_ssc),
        'total_depth': total_strain_depth(w_ssc),
    }),
)
plt.colorbar()