In [None]:
%load_ext autoreload

In [None]:
%autoreload

In [None]:
# Move the working directory up one level and display the resulting absolute
# path as a sanity check. Presumably this makes the relative 'meta/', 'ref/'
# and 'data/' paths used in later cells resolve -- confirm against the repo layout.
# NOTE: the change is relative to the current directory, so re-running this
# cell without restarting the kernel moves up one more level each time.
import os
os.chdir('..')
os.path.realpath(os.path.curdir)

In [None]:
import xarray as xr
from glob import glob
import pandas as pd
from lib.pandas_util import idxwhere
from sklearn.cluster import AgglomerativeClustering
import sfacts as sf
import matplotlib as mpl
import matplotlib.pyplot as plt
from lib.plot import construct_ordered_palette
from tqdm import tqdm
import numpy as np
import seaborn as sns
import lib.plot
import scipy.stats
import scipy as sp

In [None]:
# Group label interpolated into the per-species 'data/sp-*.{group}...' file
# names read in later cells.
group = 'hmp2'

In [None]:
# Load the four metadata tables, each indexed by its own identifier column.
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
prep = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

# Follow the key columns library -> preparation -> stool -> subject so every
# library row carries the metadata of all four tables. Columns of `prep` that
# clash with `mgen` get a trailing underscore.
meta = (
    mgen
    .join(prep, on='preparation_id', rsuffix='_')
    .join(stool, on='stool_id')
    .join(subject, on='subject_id')
)
# The joins must not have duplicated any library row.
assert meta.index.is_unique

In [None]:
# Read the header-less taxonomy table, index it by species id (as a string)
# and keep only the ';'-delimited taxonomy string plus a temporary split list.
species_taxonomy = (
    pd.read_table(
        'ref/gtpro/species_taxonomy_ext.tsv',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index('species_id')[['taxonomy_string']]
    .assign(taxonomy_split=lambda x: x.taxonomy_string.str.split(';'))
)

# Add one column per rank prefix, taken from positions 1..6 of the split
# string (position 0 is not extracted).
rank_prefixes = ['p__', 'c__', 'o__', 'f__', 'g__', 's__']
for level_number, level_name in enumerate(rank_prefixes, start=1):
    species_taxonomy = species_taxonomy.assign(
        **{level_name: species_taxonomy.taxonomy_split.apply(lambda x: x[level_number])}
    )
species_taxonomy = species_taxonomy.drop(columns=['taxonomy_split'])

In [None]:
# For every species that has a pairwise-distance file, record how many samples
# and positions its genotype dataset holds, and keep the species whose dataset
# has more than 100 of each.
num_species_sample = {}
num_species_position = {}
species_list = []

data_suffix = f'.{group}.a.r.proc.gtpro.filt-poly05-cvrg10.mgen'
pdist_suffix = f'{data_suffix}.pdist.nc'

# glob() returns paths in filesystem order; sort so that species_list (and
# everything derived from its order) is the same on every run.
for path in sorted(glob(f'data/sp-*{pdist_suffix}')):
    species_id = path[len('data/sp-'):-len(pdist_suffix)]
    # Only the dimension sizes are needed, so read them and close the file
    # immediately instead of leaving one open handle per species.
    with xr.open_dataset(f'data/sp-{species_id}{data_suffix}.nc') as ds:
        n_sample = ds.sizes['sample']
        n_position = ds.sizes['position']
    num_species_sample[species_id] = n_sample
    num_species_position[species_id] = n_position
    if (n_sample > 100) and (n_position > 100):
        species_list.append(species_id)

In [None]:
# Show how many species passed the filter and the last ten ids in sorted
# (string) order.
len(species_list), sorted(species_list)[-10:]

In [None]:
# Cluster the samples of every species into "strains" by complete-linkage
# agglomerative clustering of the precomputed pairwise-distance matrix, cut at
# a distance threshold of 0.02.
species_strain_label = {}
species_strain_counts = {}

phyla = species_taxonomy.loc[species_list].p__.unique()
palette = construct_ordered_palette(phyla, cm='tab20')

# scikit-learn 1.2 renamed the `affinity` argument to `metric` and 1.4 removed
# `affinity` entirely; pick whichever keyword the installed version accepts.
precomputed_kwarg = (
    'metric' if 'metric' in AgglomerativeClustering().get_params() else 'affinity'
)

for species_id in tqdm(species_list):
    pdmat = xr.load_dataarray(f'data/sp-{species_id}.{group}.a.r.proc.gtpro.filt-poly05-cvrg10.mgen.pdist.nc')
    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=0.02,
        linkage='complete',
        **{precomputed_kwarg: 'precomputed'},
    ).fit(pdmat)
    agg = pd.Series(clustering.labels_, index=pdmat.sampleA)
    # Labels are stored as strings; the counts give the number of samples per cluster.
    species_strain_label[species_id] = agg.astype(str)
    species_strain_counts[species_id] = agg.value_counts()

In [None]:
# Per-species summary: total clustered samples, number of distinct clusters
# ("strains"), their ratio, and the taxonomy columns, ordered by taxonomy string.
strain_summary = {
    s: (counts.sum(), len(counts)) for s, counts in species_strain_counts.items()
}
species_data = (
    pd.DataFrame(strain_summary, index=['num_samples', 'num_strains'])
    .T
    .assign(samples_per_strain=lambda x: x.num_samples / x.num_strains)
    .join(species_taxonomy)
    .sort_values('taxonomy_string')
)

taxa = species_data.p__.unique()
palette = construct_ordered_palette(taxa, cm='tab20')

# Samples vs. strains per species on log-log axes, coloured by phylum.
plt.scatter(
    'num_samples',
    'num_strains',
    data=species_data,
    c=species_data.p__.map(palette),
    label='__none__',
)
plt.yscale('log')
plt.xscale('log')

# Empty scatters exist only to create one legend entry per phylum.
for tax in taxa:
    plt.scatter([], [], label=tax, color=palette[tax])
# Reference diagonal: one strain per sample.
plt.plot([0, 1e4], [0, 1e4], lw=1, linestyle='--', color='grey', zorder=0)
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
def rarifaction(value_counts, rng=None):
    """Simulate an accumulation curve of distinct labels.

    Draws ``value_counts.sum()`` labels with replacement, each label chosen
    with probability proportional to its count, and returns the running number
    of distinct labels seen after each draw.

    Parameters
    ----------
    value_counts : pd.Series
        Counts per label; the index holds the labels.
    rng : optional
        Object providing a NumPy-style ``choice`` method (for example
        ``np.random.default_rng(seed)``). Pass one to make the curve
        reproducible. When omitted, the global ``np.random`` state is used,
        as before.

    Returns
    -------
    pd.Series
        Non-decreasing counts of distinct labels, one entry per draw.
    """
    if rng is None:
        rng = np.random
    total = value_counts.sum()
    draws = pd.Series(
        rng.choice(
            value_counts.index,
            size=total,
            replace=True,
            p=value_counts / total,
        )
    )
    # A draw adds to the curve only the first time its label appears.
    return (~draws.duplicated()).cumsum()

fig, ax = plt.subplots(figsize=(10, 10))

# One simulated accumulation curve per species, coloured by its phylum.
for species_id in species_list:
    curve = rarifaction(species_strain_counts[species_id])
    ax.plot(curve, color=palette[species_taxonomy.loc[species_id].p__])

# Empty scatters exist only to create one legend entry per phylum.
for p__ in phyla:
    ax.scatter([], [], label=p__, color=palette[p__])

# Reference diagonal (every draw yields a new label), drawn with equal axis scaling.
ax.plot([0, 1e3], [0, 1e3], lw=1, linestyle='--', color='k')
ax.set_aspect(1)

ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Plot samples_per_strain grouped by phylum, using only species with more than
# 200 clustered samples, then rotate the x tick labels via the project helper.
lib.plot.boxplot_with_points(x='p__', y='samples_per_strain', palette=palette, data=species_data[species_data.num_samples > 200])
lib.plot.rotate_xticklabels()

In [None]:
from scipy.spatial.distance import pdist, squareform

# Sample x species table of strain labels as floats; NaN where a sample has no
# label for a species.
d = pd.DataFrame(species_strain_label).astype(float)
has_label = d.notna()
# Keep samples labelled in more than 10 species and species labelled in more
# than 100 samples (both counted before filtering).
d = d.loc[has_label.sum(1) > 10, has_label.sum() > 100]

# Metadata rows aligned to the retained samples.
m = meta.loc[d.index]

# Number of retained species with a label, per sample.
num_strains = d.notna().sum(1)

def frac_shared_strains(x, y):
    """Fraction of positions labelled in both x and y where the labels match.

    x and y are float arrays with NaN marking a missing label. NaN never
    compares equal, so only positions labelled in both can count as matches.
    Returns NaN when no position is labelled in both.
    """
    # Vectorised count; the builtin sum() would iterate element by element.
    num_both_labelled = (~np.isnan(x) & ~np.isnan(y)).sum()
    if num_both_labelled == 0:
        # Make the undefined 0/0 case explicit instead of relying on a
        # RuntimeWarning-producing division.
        return np.nan
    return (x == y).sum() / num_both_labelled

def num_shared_strains(x, y):
    """Number of positions where x and y carry the same label.

    Positions that are NaN in either vector never count, because NaN does not
    compare equal to anything.
    """
    matches = x == y
    return matches.sum()

def _single_column_differs(x, y):
    """Scalar bool: True when the two one-column rows x and y differ.

    Returns a plain scalar rather than a length-1 array, so pdist does not
    have to coerce a size-1 array into a scalar (deprecated since NumPy 1.25).
    """
    return x[0] != y[0]

# All results are in condensed (pdist) form; the pair order follows the rows
# of `d`, and `m` shares that row order.
shared_strains_num = pdist(d, metric=num_shared_strains)
shared_strains_frac = pdist(d, metric=frac_shared_strains)

# Boolean masks over sample pairs: True where the two samples differ in the
# given metadata field.
diff_subject_mask = pdist(m[['subject_id']], metric=_single_column_differs).astype(bool)
diff_stool_mask = pdist(m[['stool_id']], metric=_single_column_differs).astype(bool)
diff_site_mask = pdist(m[['site']], metric=_single_column_differs).astype(bool)
diff_diagnosis_mask = pdist(m[['ibd_diagnosis']], metric=_single_column_differs).astype(bool)
# True where exactly one sample of the pair has ibd_diagnosis == 'nonIBD'.
diff_has_ibd_mask = pdist(m[['ibd_diagnosis']] == 'nonIBD', metric=_single_column_differs).astype(bool)

In [None]:
fig, axs = plt.subplots(5, 2, figsize=(10, 15))

# (label, mask of "different" pairs, mask of "same" pairs). 'stool' is compared
# within a subject; every other non-subject comparison is restricted to pairs
# from different subjects.
mask_specs = [
    ('subj', diff_subject_mask, ~diff_subject_mask),
    ('site', diff_site_mask & diff_subject_mask, ~diff_site_mask & diff_subject_mask),
    ('stool', diff_stool_mask & ~diff_subject_mask, ~diff_stool_mask & ~diff_subject_mask),
    ('diseased', diff_has_ibd_mask & diff_subject_mask, ~diff_has_ibd_mask & diff_subject_mask),
    ('form', diff_diagnosis_mask & diff_subject_mask, ~diff_diagnosis_mask & diff_subject_mask),
]
# (label, condensed pairwise values, histogram bins).
value_specs = [
    ('num', shared_strains_num, np.arange(50)),
    ('frac', shared_strains_frac, np.linspace(0, 1, num=21)),
]

# One row of axes per metadata comparison, one column per sharing measure.
for ax_row, (var_title, diff_masking, same_masking) in zip(axs, mask_specs):
    for ax, (val_title, shared_strains, bins) in zip(ax_row, value_specs):
        same_vals = shared_strains[same_masking]
        diff_vals = shared_strains[diff_masking]

        ax.hist(same_vals, bins=bins, density=True, alpha=0.5, label=f'same {var_title}')
        ax.hist(diff_vals, bins=bins, density=True, alpha=0.5, label=f'diff {var_title}')
        ax.legend()
        ax.set_title((var_title, val_title))
        ax.set_yscale('log')

        # One-sided test: are "same" pairs shifted towards larger values than "diff" pairs?
        mwu_pvalue = sp.stats.mannwhitneyu(same_vals, diff_vals, alternative='greater', nan_policy='omit').pvalue
        ax.annotate(f'{mwu_pvalue:0.2e}', xy=(0.5, 0.5), xycoords='axes fraction')

        # Means over the non-NaN values, rounded for display.
        mean_same = np.round(np.mean(same_vals[~np.isnan(same_vals)]), 3)
        mean_diff = np.round(np.mean(diff_vals[~np.isnan(diff_vals)]), 3)
        print(var_title, val_title, mean_same, mean_diff, mwu_pvalue)

In [None]:
# Samples retained in the strain-label table `d`; reused below to subset the
# strain-depth table to the same rows in the same order.
cluster_analysis_sample_list = d.index

In [None]:
# Read the long-format (sample, strain) depth table, pivot it to one row per
# sample and one column per strain (missing combinations become 0), and keep
# the cluster-analysis samples in their existing order.
strain_depth = (
    pd.read_table(
        'data/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.fit-sfacts9-s75-g10000-seed0.collapse-10.strain_depth.tsv',
        index_col=['sample', 'strain'],
    )
    .squeeze()
    .unstack(fill_value=0)
    .loc[cluster_analysis_sample_list]
)

In [None]:
# A strain counts as present in a sample when its depth exceeds this threshold.
thresh = 1e-1

# Condensed pairwise Jaccard distance between the samples' presence/absence profiles.
strain_present = strain_depth > thresh
sfacts_jacc = pdist(strain_present, metric='jaccard')

In [None]:
def _count_copresent(x, y):
    """Number of positions that are True in both boolean vectors."""
    return (x & y).sum()

# Condensed pairwise count of strains present (depth > thresh) in both samples.
sfacts_num = pdist((strain_depth > thresh), metric=_count_copresent)

In [None]:
fig, axs = plt.subplots(5, 4, figsize=(15, 15))

# (label, mask of "different" pairs, mask of "same" pairs). 'stool' is compared
# within a subject; every other non-subject comparison is restricted to pairs
# from different subjects.
mask_specs = [
    ('subj', diff_subject_mask, ~diff_subject_mask),
    ('site', diff_site_mask & diff_subject_mask, ~diff_site_mask & diff_subject_mask),
    ('stool', diff_stool_mask & ~diff_subject_mask, ~diff_stool_mask & ~diff_subject_mask),
    ('diseased', diff_has_ibd_mask & diff_subject_mask, ~diff_has_ibd_mask & diff_subject_mask),
    ('form', diff_diagnosis_mask & diff_subject_mask, ~diff_diagnosis_mask & diff_subject_mask),
]
# (label, condensed pairwise values, histogram bins). The Jaccard distance is
# flipped to a similarity so that larger always means "more shared".
value_specs = [
    ('num', shared_strains_num, np.arange(50)),
    ('frac', shared_strains_frac, np.linspace(0, 1, num=21)),
    ('sf_jacc', 1 - sfacts_jacc, np.linspace(0, 1, num=21)),
    ('sf_num', sfacts_num, np.arange(50)),
]

# One row of axes per metadata comparison, one column per sharing measure.
for ax_row, (var_title, diff_masking, same_masking) in zip(axs, mask_specs):
    for ax, (val_title, shared_strains, bins) in zip(ax_row, value_specs):
        same_vals = shared_strains[same_masking]
        diff_vals = shared_strains[diff_masking]

        ax.hist(same_vals, bins=bins, density=True, alpha=0.5, label=f'same {var_title}')
        ax.hist(diff_vals, bins=bins, density=True, alpha=0.5, label=f'diff {var_title}')
        ax.legend()
        ax.set_title((var_title, val_title))
        ax.set_yscale('log')

        # One-sided test: are "same" pairs shifted towards larger values than "diff" pairs?
        mwu_pvalue = sp.stats.mannwhitneyu(same_vals, diff_vals, alternative='greater', nan_policy='omit').pvalue
        ax.annotate(f'{mwu_pvalue:0.2e}', xy=(0.5, 0.5), xycoords='axes fraction')

        # Means over the non-NaN values, rounded for display.
        mean_same = np.round(np.mean(same_vals[~np.isnan(same_vals)]), 3)
        mean_diff = np.round(np.mean(diff_vals[~np.isnan(diff_vals)]), 3)
        print(var_title, val_title, mean_same, mean_diff, mwu_pvalue)