## Preamble

In [None]:
# Load IPython's autoreload extension, then select mode 0, which (per the
# IPython docs) disables automatic reloading; modules are then only reloaded
# when `%autoreload` is invoked explicitly.
%load_ext autoreload
%autoreload 0

In [None]:
# With no argument, `%autoreload` triggers a one-off reload of imported
# modules right now (re-run this cell after editing library code).
%autoreload

In [None]:
import sys
# Add a local StrainFacts checkout to the import path (presumably so that
# `import sfacts` in the next cell resolves to it — confirm).
# NOTE(review): this is an absolute, user-specific path and will not exist on
# other machines; installing the package would make the notebook portable.
sys.path.append('/pollard/home/bsmith/Projects/haplo-benchmark/include/StrainFacts')

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
import torch
import pyro
import scipy as sp

import lib.plot
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
from lib.pandas_util import idxwhere


import sfacts as sf

# from lib.project_style import color_palette, major_allele_frequency_bins
# from lib.project_data import metagenotype_db_to_xarray
# from lib.plot import ordination_plot, mds_ordination, nmds_ordination
# import lib.plot
# from lib.plot import construct_ordered_pallete
# from lib.pandas_util import idxwhere

## Load Data

In [None]:
# Load the saved sfacts World (the model fit) that every later cell analyses.
fit = sf.data.World.load(
    'data/zshi.sp-102506.metagenotype.filt-poly05-cvrg25-g500.fit-sfacts13-s500-g500-seed0.world.nc'
)
# Display the dimension sizes of the loaded object.
fit.sizes

In [None]:
# Reference genotypes: load the metagenotype file, keep only the positions
# present in `fit`, and convert to estimated genotypes with a zero pseudocount.
ref = (
    sf.data.Metagenotypes
    .load('data/gtprodb.sp-102506.genotype.nc')
    .mlift('sel', position=fit.position)
    .to_estimated_genotypes(pseudo=0)
)
# Display the dimension sizes of the result.
ref.sizes

In [None]:
# Overlay three entropy distributions on shared bins over [0, 1]:
#   1. genotypes estimated directly from the fit's metagenotypes
#      (pseudocount 1e-10),
#   2. the fitted genotypes,
#   3. the reference genotypes.
bins = np.linspace(0, 1, num=51)

for entropy_values in (
    fit.metagenotypes.to_estimated_genotypes(pseudo=1e-10).entropy(),
    fit.genotypes.entropy(),
    ref.entropy(),
):
    plt.hist(entropy_values, bins=bins, alpha=0.5)

plt.yscale('log')
# Trailing None keeps the notebook from echoing a return value.
None

In [None]:
# Two stacked log-scale histograms on shared bins over [0, 1]:
#   top    - every entry of the fitted genotype array,
#   bottom - the maximum community value over the "strain" dimension.
bins = np.linspace(0, 1, num=51)

fig, axs = plt.subplots(2)
panel_values = (
    fit.genotypes.values.flatten(),
    fit.communities.max("strain").values.flatten(),
)
for ax, values in zip(axs, panel_values):
    ax.hist(values, bins=bins)
    ax.set_yscale('log')
# Trailing None keeps the notebook from echoing a return value.
None

In [None]:
# Draw the genotype plot for the fit using sfacts' plotting helper.
# scalex / scaley presumably control figure size per column / row — confirm.
sf.plot.plot_genotype(
    fit,
    scaley=2e-2,
    scalex=1e-3,
    yticklabels=0,
)

In [None]:
# Same genotype plot for the reference genotypes, with a smaller scaley
# than the one used for the fit.
sf.plot.plot_genotype(
    ref,
    scaley=1e-2,
    scalex=1e-3,
    yticklabels=0,
)

In [None]:
# Partition the fitted strains by genotype entropy at a 0.25 cutoff.
# Both comparisons are strict, so a strain at exactly 0.25 lands in neither set.
strain_entropy = fit.genotypes.entropy()
fit_genotypes_filt = fit.genotypes.mlift('sel', strain=strain_entropy < 0.25)
fit_genotypes_highent = fit.genotypes.mlift('sel', strain=strain_entropy > 0.25)

# Display the sizes of the low- and high-entropy subsets side by side.
(fit_genotypes_filt.sizes, fit_genotypes_highent.sizes)

In [None]:
# Histogram of pairwise genotype distances among the low-entropy strains.
# squareform converts between square and condensed distance layouts; a
# later cell passes .pdist() as a precomputed matrix, so this presumably
# yields the condensed vector — confirm.
pairwise_dists = squareform(fit_genotypes_filt.pdist())
distance_bins = np.linspace(0, 0.5, num=101)

plt.hist(pairwise_dists, bins=distance_bins)
plt.yscale('log')
# Trailing None keeps the notebook from echoing a return value.
None

In [None]:
# Stack the high-entropy ("bad") and low-entropy ("good") strains along the
# strain dimension. The row colouring below relies on strain labels starting
# with the dict key (presumably added by concat — confirm).
g = sf.data.Genotypes.concat(
    {
        'bad': fit_genotypes_highent,
        'good': fit_genotypes_filt,
    },
    dim='strain',
)

# Row colours flag the "good" strains.
sf.plot.plot_genotype(
    g,
    row_colors_func=lambda world: world.strain.str.startswith('good'),
    scaley=2e-2,
    scalex=1e-3,
    yticklabels=0,
)

In [None]:
# Stack reference genotypes, low-entropy fitted strains ("fit") and
# high-entropy fitted strains ("ent") along the strain dimension.
g = sf.data.Genotypes.concat(
    {
        'ref': ref,
        'fit': fit_genotypes_filt,
        'ent': fit_genotypes_highent,
    },
    dim='strain',
)

# Row colours flag strains whose label starts with "fit".
sf.plot.plot_genotype(
    g,
    row_colors_func=lambda world: world.strain.str.startswith('fit'),
    scaley=2e-2,
    scalex=1e-3,
    yticklabels=0,
)

In [None]:
# Stack reference genotypes and low-entropy fitted strains only (the
# high-entropy strains are left out of this view).
g = sf.data.Genotypes.concat(
    {
        'ref': ref,
        'fit': fit_genotypes_filt,
    },
    dim='strain',
)

# Row colours flag strains whose label starts with "fit".
sf.plot.plot_genotype(
    g,
    row_colors_func=lambda world: world.strain.str.startswith('fit'),
    scaley=2e-2,
    scalex=1e-3,
    yticklabels=0,
)

In [None]:
# Sample metadata table, indexed by NCBI accession so it can be joined to
# per-sample results in later cells.
sample_meta = pd.read_table('raw/shi2019s13.tsv')
sample_meta = sample_meta.set_index('NCBI Accession Number')

# Display the number of metadata rows per (Continent, Country, Study).
sample_meta.groupby(['Continent', 'Country', 'Study']).apply(len)

In [None]:
# Dominant strain per sample: argmax of the community array along "strain".
# NOTE(review): if `fit.communities.data` is an xarray DataArray, argmax gives
# integer positions rather than strain labels — confirm the two coincide.
dominant_strain = fit.communities.data.argmax("strain").to_series()
# The 20 strains that are dominant in the most samples, most frequent first.
strain_tally = dominant_strain.value_counts().sort_values(ascending=False)
top_strains = list(strain_tally.head(20).index)

group_cols = ['Continent', 'Country', 'Study']

# Sample counts per (group, dominant strain), one column per strain.
strain_counts = (
    dominant_strain
    .to_frame(name='strain')
    .join(sample_meta, how='inner')
    .groupby(group_cols + ['strain'])
    .apply(len)
    .unstack('strain', fill_value=0)
)
# Convert each row to fractions, then lump all non-top strains into "other".
strain_fracs = strain_counts.apply(lambda row: row / row.sum(), axis=1)
strain_fracs = strain_fracs.assign(other=lambda x: 1 - x[top_strains].sum(1))
# Keep groups with more than 5 rows in the metadata table. This counts all
# metadata rows, not only the samples that joined to the fit.
group_is_large = sample_meta.groupby(group_cols).apply(len) > 5
d = strain_fracs[top_strains + ['other']][group_is_large]


# Result is discarded; the call is kept exactly as in the original cell.
lib.plot.construct_ordered_pallete(top_strains)

# Stacked bars: one bar per group, one colour per top strain plus "other".
palette = lib.plot.construct_ordered_pallete(top_strains, cm='tab20c', other='whitesmoke')
ax = d.plot.bar(stacked=True, color=palette, figsize=(10, 5))
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Community values multiplied by `fit.data.mu`.
# NOTE(review): mu is presumably a per-sample depth/scale term — confirm.
scaled_communities = fit.communities.data * fit.data.mu

# For each remaining coordinate (after reducing over "sample"): the summed
# and the largest scaled value.
d = pd.DataFrame({
    'total': scaled_communities.sum("sample"),
    'maximum': scaled_communities.max("sample"),
})

# Joint distribution of the two quantities on a log10 scale.
sns.jointplot(y='total', x='maximum', data=np.log10(d))

In [None]:
# Overlaid histograms of log10(sum over samples) and log10(max over samples)
# of the community values multiplied by `fit.data.mu`, on shared bins [0, 4].
scaled_communities = fit.communities.data * fit.data.mu
log_bins = np.linspace(0, 4, num=101)

plt.hist(scaled_communities.sum("sample").to_series().apply(np.log10), bins=log_bins)
plt.hist(scaled_communities.max("sample").to_series().apply(np.log10), bins=log_bins)

plt.yscale('log')

In [None]:
# Dominant strain per sample: argmax of the community array along "strain".
# NOTE(review): if `fit.communities.data` is an xarray DataArray, argmax gives
# integer positions rather than strain labels — confirm the two coincide.
dominant_strain = fit.communities.data.argmax("strain").to_series()
# The 20 strains that are dominant in the most samples, most frequent first.
strain_tally = dominant_strain.value_counts().sort_values(ascending=False)
top_strains = list(strain_tally.head(20).index)

group_cols = ['Continent', 'Country', 'Study']

# Long-format sample counts per (Continent, Country, Study, dominant strain),
# restricted to samples that appear in the metadata table.
d0 = (
    dominant_strain
    .to_frame(name='strain')
    .join(sample_meta, how='inner')
    .groupby(group_cols + ['strain'])
    .apply(len)
)

# Keep groups with at least 5 joined samples (counted from d0 itself).
group_is_large = d0.groupby(group_cols).sum() >= 5

# Wide table of per-group strain fractions, non-top strains lumped as "other".
strain_fracs = d0.unstack('strain', fill_value=0).apply(lambda row: row / row.sum(), axis=1)
strain_fracs = strain_fracs.assign(other=lambda x: 1 - x[top_strains].sum(1))
d1 = strain_fracs[top_strains + ['other']][group_is_large]


# Result is discarded; the call is kept exactly as in the original cell.
lib.plot.construct_ordered_pallete(top_strains)

# Stacked bars: one bar per group, one colour per top strain plus "other".
palette = lib.plot.construct_ordered_pallete(top_strains, cm='tab20c', other='whitesmoke')
ax = d1.plot.bar(stacked=True, color=palette, figsize=(10, 5))
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Dominant strain per sample: argmax of the community array along "strain".
# NOTE(review): if `fit.communities.data` is an xarray DataArray, argmax gives
# integer positions rather than strain labels — confirm the two coincide.
dominant_strain = fit.communities.data.argmax("strain").to_series()

group_cols = ['Continent', 'Country', 'Study']

# Long-format sample counts per (Continent, Country, Study, dominant strain).
d0 = (
    dominant_strain
    .to_frame(name='strain')
    .join(sample_meta, how='inner')
    .groupby(group_cols + ['strain'])
    .apply(len)
)

# Keep groups with at least 5 joined samples.
group_is_large = d0.groupby(group_cols).sum() >= 5

# Per-group strain fractions over ALL strains (no top-N / "other" lumping).
strain_fracs = d0.unstack('strain', fill_value=0).apply(lambda row: row / row.sum(), axis=1)
# One column per strain in `fit.strain` order, zero where a strain is never
# dominant, so the columns line up with the genotype linkage used below.
d1 = (
    strain_fracs[group_is_large]
    .reindex(fit.strain, axis='columns')
    .fillna(0)
)


# Heatmap: columns ordered by the genotype linkage; rows are clustered by
# seaborn using the cosine metric. PowerNorm(1/4) compresses the colour scale.
sns.clustermap(
    d1,
    norm=mpl.colors.PowerNorm(1/4),
    col_linkage=fit.genotypes.linkage(),
    metric='cosine',
    cmap='pink_r',
)

In [None]:
import inspect

# scikit-learn renamed AgglomerativeClustering's `affinity` argument to
# `metric` (deprecated in 1.2, removed in 1.4). Pick whichever the installed
# version accepts so the cell runs on both old and new releases.
_distance_kwarg = (
    'metric'
    if 'metric' in inspect.signature(AgglomerativeClustering).parameters
    else 'affinity'
)

# Complete-linkage clustering of strains on the precomputed genotype distance
# matrix, cut at a distance threshold of 0.15 (n_clusters=None lets the
# threshold decide the number of clusters). Result: cluster id per strain.
clust = pd.Series(
    AgglomerativeClustering(
        distance_threshold=0.15,
        n_clusters=None,
        linkage='complete',
        **{_distance_kwarg: 'precomputed'},
    ).fit_predict(fit.genotypes.pdist()),
    index=fit.strain,
)


# Sum community values over the strains in each cluster, giving a
# sample x cluster table. Grouping the transposed frame replaces
# `groupby(clust, axis='columns')`, which is deprecated in pandas >= 2.1;
# the result is the same.
agg_communities = (
    fit.communities.to_series().unstack()
    .T
    .groupby(clust)
    .sum()
    .T
)

# Display the number of strains in each cluster.
clust.value_counts()

In [None]:
# Dominant strain cluster per sample: the column label of the largest value
# in each row of the cluster-aggregated community table.
dominant_strain = agg_communities.idxmax(1)
# The 10 clusters that are dominant in the most samples, most frequent first.
cluster_tally = dominant_strain.value_counts().sort_values(ascending=False)
top_strains = list(cluster_tally.head(10).index)

group_cols = ['Continent', 'Country', 'Study']

# Sample counts per (group, dominant cluster), one column per cluster.
cluster_counts = (
    dominant_strain
    .to_frame(name='strain')
    .join(sample_meta, how='inner')
    .groupby(group_cols + ['strain'])
    .apply(len)
    .unstack('strain', fill_value=0)
)
# Convert each row to fractions, then lump all non-top clusters into "other".
cluster_fracs = cluster_counts.apply(lambda row: row / row.sum(), axis=1)
cluster_fracs = cluster_fracs.assign(other=lambda x: 1 - x[top_strains].sum(1))
# Keep groups with more than 5 rows in the metadata table. This counts all
# metadata rows, not only the samples that joined to the fit.
group_is_large = sample_meta.groupby(group_cols).apply(len) > 5
d = cluster_fracs[top_strains + ['other']][group_is_large]


# Result is discarded; the call is kept exactly as in the original cell.
lib.plot.construct_ordered_pallete(top_strains)

# Stacked bars: one bar per group, one colour per top cluster plus "other".
palette = lib.plot.construct_ordered_pallete(top_strains, cm='tab20c', other='whitesmoke')
ax = d.plot.bar(stacked=True, color=palette, figsize=(10, 5))
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Dominant strain cluster per sample: the column label of the largest value
# in each row of the cluster-aggregated community table.
dominant_strain = agg_communities.idxmax(1)
# The 20 clusters that are dominant in the most samples, most frequent first.
cluster_tally = dominant_strain.value_counts().sort_values(ascending=False)
top_strains = list(cluster_tally.head(20).index)

group_cols = ['Continent', 'Country', 'Study']

# Long-format sample counts per (Continent, Country, Study, dominant cluster),
# restricted to samples that appear in the metadata table.
d0 = (
    dominant_strain
    .to_frame(name='strain')
    .join(sample_meta, how='inner')
    .groupby(group_cols + ['strain'])
    .apply(len)
)

# Keep groups with at least 5 joined samples (counted from d0 itself).
group_is_large = d0.groupby(group_cols).sum() >= 5

# Wide table of per-group cluster fractions, non-top clusters lumped as "other".
cluster_fracs = d0.unstack('strain', fill_value=0).apply(lambda row: row / row.sum(), axis=1)
cluster_fracs = cluster_fracs.assign(other=lambda x: 1 - x[top_strains].sum(1))
d1 = cluster_fracs[top_strains + ['other']][group_is_large]


# Result is discarded; the call is kept exactly as in the original cell.
lib.plot.construct_ordered_pallete(top_strains)

# Stacked bars: one bar per group, one colour per top cluster plus "other".
palette = lib.plot.construct_ordered_pallete(top_strains, cm='tab20c', other='whitesmoke')
ax = d1.plot.bar(stacked=True, color=palette, figsize=(10, 5))
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Dominant strain cluster per sample: the column label of the largest value
# in each row of the cluster-aggregated community table.
dominant_strain = agg_communities.idxmax(1)

group_cols = ['Continent', 'Country', 'Study']

# Long-format sample counts per (Continent, Country, Study, dominant cluster),
# restricted to samples that appear in the metadata table.
d0 = (
    dominant_strain
    .to_frame(name='strain')
    .join(sample_meta, how='inner')
    .groupby(group_cols + ['strain'])
    .apply(len)
)

# Per-group cluster fractions over ALL clusters, for groups with at least 5
# joined samples. No "other" column is added: this cell does not define
# `top_strains`, so computing one would depend on whichever earlier cell ran
# last, and it would double-count mass already shown in the per-cluster
# columns of the heatmap.
d1 = (
    d0
    .unstack('strain', fill_value=0)
    .apply(lambda x: x / x.sum(), axis=1)
    [(d0.groupby(group_cols).sum() >= 5)]
)


# Clustered heatmap of the fractions; PowerNorm(1/2) compresses the colour scale.
sns.clustermap(d1, norm=mpl.colors.PowerNorm(1/2))