# Preamble

In [None]:
# Load IPython's autoreload extension.
# NOTE(review): no `%autoreload <mode>` magic follows in this cell, so automatic
# reloading of edited modules may not actually be switched on — confirm.
%load_ext autoreload

In [None]:
import os as _os

# Run the rest of the notebook from the project root so that the relative
# "meta/", "data/" and "ref/" paths used in later cells resolve.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)

# Echo the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

## Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform, cdist
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Reorder *labels* to follow the leaf order of a hierarchical clustering.

    ``linkage`` is a SciPy linkage matrix; ``labels`` must support indexing by
    a list of integer positions (e.g. a NumPy array or pandas Index), one label
    per original observation.
    """
    root_node = sp.cluster.hierarchy.to_tree(linkage)
    # Pre-order traversal visits leaves in left-to-right order; collect their ids.
    leaf_positions = root_node.pre_order(lambda node: node.id)
    return labels[leaf_positions]

def label_een_experiment_sample(x):
    """Build a human-readable plot label for one row of the sample table.

    ``x`` is a row (pandas Series) whose ``name`` is the sample ID. The fields
    read depend on ``x.sample_type``:

    * ``"human"``: ``collection_date_relative_een_end`` and ``diet_or_media``
    * ``"Fermenter_inoculum"`` / ``"Fermenter"``: ``source_samples`` and ``diet_or_media``
    * ``"mouse"``: ``source_samples``, ``diet_or_media`` and ``status_mouse_inflamed``
      (``"Inflamed"`` or ``"not_Inflamed"``)

    Raises:
        ValueError: if ``sample_type`` is not one of the above, or a mouse
            sample has an unrecognised ``status_mouse_inflamed``.
    """
    if x.sample_type == "human":
        return f"[{x.name}] {x.collection_date_relative_een_end} {x.diet_or_media}"
    if x.sample_type == "Fermenter_inoculum":
        return f"[{x.name}] {x.source_samples} inoc {x.diet_or_media}"
    if x.sample_type == "Fermenter":
        return f"[{x.name}] {x.source_samples} {x.diet_or_media}"
    if x.sample_type == "mouse":
        # Short suffix describing the mouse's inflammation status.
        inflammation_tag = {"Inflamed": "inflam", "not_Inflamed": "not_inf"}.get(
            x.status_mouse_inflamed
        )
        if inflammation_tag is None:
            # Report the offending field accurately (previously mislabelled as a "sample type").
            raise ValueError(
                f"mouse inflammation status {x.status_mouse_inflamed} not understood"
            )
        return f"[{x.name}] {x.source_samples} {x.diet_or_media} {inflammation_tag}"
    raise ValueError(f"sample type {x.sample_type} not understood")

In [None]:
# Presentation-sized fonts/lines for all figures, and a higher default
# resolution for inline plots.
sns.set_context("talk")
plt.rcParams.update({"figure.dpi": 150})

# Prepare Metadata

In [None]:
# Display order for subjects in plots (deliberately not alphabetical: "H" follows "B").
subject_order = "A B H C D E F G K L M N O P Q R S T U".split()

In [None]:
# Per-sample metadata table, indexed by sample_id.
sample = (
    pd.read_table("meta/een-mgen/sample.tsv")
    .assign(
        # Tuple label: (collection date relative to EEN end, diet/media, sample_id).
        label=lambda x: x[
            ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
        ].apply(tuple, axis=1)
    )
    .set_index("sample_id")
    # Human-readable label built row by row; sample_id (now the index) is each row's name.
    .assign(full_label=lambda d: d.apply(label_een_experiment_sample, axis=1))
)

# Prepare Data

In [None]:
# zOTU count table: rows are zOTUs (indexed by "#OTU ID"); columns are samples
# plus a "taxonomy" column.
rotu_counts = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")
rotu_taxonomy = rotu_counts.taxonomy
# Drop the non-count taxonomy column and transpose to samples x zOTUs.
rotu_counts = rotu_counts.drop(columns=["taxonomy"]).T
# Per-sample relative abundance (a sample with zero total counts would become all-NaN).
rotu_rabund = rotu_counts.divide(rotu_counts.sum(1), axis=0)

# Average-linkage clustering of samples on Bray-Curtis dissimilarity of zOTU
# relative abundances, with optimal leaf ordering.
sample_rotu_bc_linkage = sp.cluster.hierarchy.linkage(
    rotu_rabund, method="average", metric="braycurtis", optimal_ordering=True
)

In [None]:
# Species depth table: long (sample, species_id, depth) records pivoted to
# samples x species, with absent combinations filled with 0.
motu_depth = (pd.read_table(
    "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
    names=['sample', "species_id", 'depth'], index_col=['sample', "species_id"],
    )
    .depth.unstack(fill_value=0)
    # Species IDs as str; sample names normalised to "CF_<int>" (int() strips zero-padding).
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    .rename({'CF_15': 'CF_11', 'CF_11': 'CF_15'})  # Sample swap
)
# Per-sample relative abundance of each species.
motu_rabund = motu_depth.divide(motu_depth.sum(1), axis=0)

motu_rabund

In [None]:
# Strain-level depth table: for every species, split its species depth across
# the strains of that species' sfacts community fit.
sotu_depth = []
missing_files = []
for species_id in motu_depth.columns:
    path = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        d = (
            # Long (sample, strain) table -> samples x strains.
            # NOTE(review): .squeeze() assumes a single value column; confirm.
            pd.read_table(path, index_col=["sample", "strain"])
            .squeeze()
            .unstack()
            .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            .rename({'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap.
        )
    except FileNotFoundError:
        # No fit for this species: all of its depth ends up in the "other" strain below.
        missing_files.append(path)
        d = pd.DataFrame([])
    # Keep strains whose abundance, summed over all samples, exceeds 0.05
    # (idxwhere presumably returns the labels where the mask is True).
    _keep_strains = idxwhere(d.sum() > 0.05)
    assert d.index.isin(motu_depth.index).all()
    # Align to every sample in motu_depth; samples absent from the fit get 0 for all kept strains.
    d = d.reindex(index=motu_depth.index, columns=_keep_strains, fill_value=0)
    # Residual (1 - sum of kept strains) becomes an "other" strain labelled -1.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    # Clip negative residuals to 0, then renormalise each sample to sum to 1.
    d[d < 0] = 0
    d = d.divide(d.sum(1), axis=0)
    # Convert strain fractions to depths by scaling with the species depth.
    d = d.multiply(motu_depth[species_id], axis=0)
    # Globally unique strain labels: "<species_id>_<strain>".
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    sotu_depth.append(d)
sotu_depth = pd.concat(sotu_depth, axis=1)
sotu_rabund = sotu_depth.divide(sotu_depth.sum(1), axis=0)
# Number of species considered vs. number without an sfacts fit file.
len(motu_depth.columns), len(missing_files)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ";"-separated seven-rank lineage string into a Series.

    The result is indexed by rank prefix ("d__" ... "s__"); values are the raw
    ";"-separated fields (prefixes included). Raises ValueError if the string
    does not have exactly seven fields.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
# Genome catalogue table; only rows whose Genome is its own Species_rep are kept,
# i.e. one row per species.
motu_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

_motu_taxonomy = (
    pd.read_table(motu_taxonomy_inpath)[lambda x: x.Genome == x.Species_rep]
    # species_id = "1" + third "-"-separated field of MGnify_accession; presumably this
    # matches the species_id labels of motu_depth — confirm.
    .assign(species_id=lambda x: "1" + x.MGnify_accession.str.split("-").str[2])
    .set_index("species_id")
)

# motu_lineage_string = _motu_taxonomy.Lineage

# One column per rank prefix (d__ ... s__), one row per species.
motu_taxonomy = _motu_taxonomy.Lineage.apply(
    parse_taxonomy_string
)  # .assign(taxonomy_string=motu_lineage_string)
motu_taxonomy

# Species Enrichment Analysis

In [None]:
def enrichment_test(d):
    """Paired EEN vs. PostEEN comparison for one species.

    ``d`` holds one row per subject with ``EEN`` and ``PostEEN`` abundance
    columns. Returns a Series with the mean per-subject log2(PostEEN / EEN),
    the two column means, and the Wilcoxon signed-rank p-value (NaN when
    scipy raises ValueError for degenerate input).
    """
    een_values = d["EEN"]
    post_een_values = d["PostEEN"]
    try:
        pvalue = sp.stats.wilcoxon(een_values, post_een_values)[1]
    except ValueError:
        pvalue = np.nan
    return pd.Series(
        {
            "log2_ratio": np.log2(post_een_values / een_values).mean(),
            "mean_EEN": een_values.mean(),
            "mean_PostEEN": post_een_values.mean(),
            "pvalue": pvalue,
        }
    )

In [None]:
# Per-species paired test of relative abundance during EEN vs. PostEEN.
motu_enrichment_results = (
    # Pseudo-count: add each species' smallest non-zero relative abundance
    # (zeros are replaced by inf before the min, so an all-zero species becomes inf).
    motu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample[["subject_id", "diet_or_media"]])
    # Average samples within each subject x diet/media phase.
    .groupby(["subject_id", "diet_or_media"])
    .mean()
    .stack()
    # Rows (subject, species), columns EEN / PostEEN; drop rows missing either phase.
    .unstack("diet_or_media")[["EEN", "PostEEN"]]
    .dropna()
    # NOTE: enrichment_test recomputes log2_ratio itself and does not read this column.
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
    .rename_axis(index=["subject_id", "motu_id"])
    .groupby(level="motu_id")
    .apply(enrichment_test)
)
motu_enrichment_results.sort_values("pvalue", ascending=True).head(20)
# Scratch lines for inspecting a single species' per-subject table `d`:
# fig, ax = plt.subplots()
# print(d.log2_ratio.mean())
# print(sp.stats.wilcoxon(d['PostEEN'], d['EEN']))
# ax.hist(d.log2_ratio, bins=20)

In [None]:
# Overall mean relative abundance per species: same pseudo-count as the
# enrichment analysis, averaged within each subject, then across subjects
# (so every subject carries equal weight).
motu_mean_rabund = (motu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample[["subject_id"]])
    .groupby(["subject_id"]).mean()).mean(0)
motu_mean_rabund

In [None]:
def pair_classifier(sample_typeA, sample_typeB):
    """Return an order-independent label for a pair of sample types.

    Identical types collapse to one name (e.g. ``"EEN"``); distinct types are
    sorted and joined with ``":"`` (e.g. ``"EEN:PostEEN"``).
    """
    return ":".join(sorted(set([sample_typeA, sample_typeB])))


def construct_turnover_analysis_data(
    dmat,
    meta,
    sample_type_var,
    stratum_var=None,
    time_var=None,
):
    """Build a long table of pairwise dissimilarities annotated with metadata.

    Parameters
    ----------
    dmat : pandas.DataFrame
        Square dissimilarity matrix with sample IDs on both axes.
    meta : pandas.DataFrame
        Per-sample metadata indexed by sample ID. Samples of ``dmat`` missing
        from ``meta``, or with a missing value in any requested column, are dropped.
    sample_type_var : str
        Column of ``meta`` used to classify each pair via ``pair_classifier``.
    stratum_var : str, optional
        Column of ``meta`` defining strata. When given, only pairs within the
        same stratum are kept and a ``stratum`` column is added.
    time_var : str, optional
        Column of ``meta`` with a numeric time. When given, ``timeA``/``timeB``
        and the absolute difference ``time_delta`` are added.

    Returns
    -------
    pandas.DataFrame
        One row per unordered sample pair (``sampleA < sampleB``; self-pairs
        excluded) with ``diss`` and ``pair_type`` columns. An empty selection
        yields an empty frame with the same columns.
    """
    var_list = [var for var in (sample_type_var, stratum_var, time_var) if var is not None]
    meta = meta.reindex(dmat.index)[var_list].dropna()
    dmat = dmat.loc[meta.index, meta.index]

    # Declare the schema up front so an empty selection still has every column
    # referenced below (previously it crashed on missing attributes).
    columns = ["sampleA", "sampleB", "sample_typeA", "sample_typeB", "diss"]
    if stratum_var is not None:
        columns += ["stratumA", "stratumB"]
    if time_var is not None:
        columns += ["timeA", "timeB"]

    records = []
    for idxA, idxB in product(meta.index, repeat=2):
        pair_data = {
            "sampleA": idxA,
            "sampleB": idxB,
            "sample_typeA": meta.loc[idxA, sample_type_var],
            "sample_typeB": meta.loc[idxB, sample_type_var],
            "diss": dmat.loc[idxA, idxB],
        }
        if stratum_var is not None:
            pair_data["stratumA"] = meta.loc[idxA, stratum_var]
            pair_data["stratumB"] = meta.loc[idxB, stratum_var]
        if time_var is not None:
            pair_data["timeA"] = meta.loc[idxA, time_var]
            pair_data["timeB"] = meta.loc[idxB, time_var]
        records.append(pair_data)

    data = pd.DataFrame(records, columns=columns)
    data["pair_type"] = [
        pair_classifier(type_a, type_b)
        for type_a, type_b in zip(data["sample_typeA"], data["sample_typeB"])
    ]
    if time_var is not None:
        data["time_delta"] = np.abs(data["timeB"] - data["timeA"])

    # Keep each unordered pair once (this also drops self-pairs). Restrict to
    # within-stratum pairs only when a stratum variable was given; previously
    # this filter ran unconditionally and failed when stratum_var was None.
    keep = data["sampleA"] < data["sampleB"]
    if stratum_var is not None:
        keep &= data["stratumA"] == data["stratumB"]
    data = data[keep]
    if stratum_var is not None:
        data = data.assign(stratum=lambda x: x.stratumA)

    return data

In [None]:
# Exploratory turnover-model fit for a single species (101378).
species_id = '101378'
inpath = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
sf_fit = (
    sf.data.World.load(
        inpath
    )
    # Same "CF_<int>" sample renaming and CF_11/CF_15 swap as applied to motu_depth.
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    # .drop_low_abundance_strains(0.01)
    # .rename_coords(strain=str)
)

# Community table from the fit and the square Bray-Curtis dissimilarity between its rows (samples).
comm = sf_fit.community.to_pandas()
comm_bc_pdist = pd.DataFrame(squareform(pdist(comm, metric='braycurtis')), index=comm.index, columns=comm.index)

# Within-subject pairs of EEN / PostEEN samples only.
turnover_data = construct_turnover_analysis_data(
    comm_bc_pdist,
    meta=sample[lambda x: x.diet_or_media.isin(['EEN', 'PostEEN'])],
    sample_type_var="diet_or_media",
    stratum_var="subject_id",
    time_var="collection_date_relative_een_end",
)

# Dissimilarity ~ pair type (no intercept) + spline in time between samples + sum-coded subject effect.
# NOTE(review): patsy's `cr()` without `constraints="center"` may itself span the
# constant, which would be collinear with `0 + pair_type`; if so, only differences
# between pair_type coefficients are identified — confirm.
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
fit = smf.ols(formula, data=turnover_data).fit()

# NOTE(review): coef_list is built here but not used later in this cell.
coef_list = []
for coef in ['pair_type[EEN]', 'pair_type[PostEEN]', 'pair_type[EEN:PostEEN]']:
    if coef in fit.params:
        coef_list.append(fit.params[coef])
    else:
        coef_list.append(np.nan)        

# Dissimilarity vs. time between samples, one colour per pair type.
for pair_type, d1 in turnover_data.groupby('pair_type'):
    plt.scatter('time_delta', 'diss', data=d1, label=pair_type)
plt.xscale('log')
plt.legend()
fit.summary()


In [None]:
# Fit the strain-turnover model for every species with enrichment results,
# collecting the pair_type coefficients per species and every sample pair.
lm_coef_results = []
all_pairs = []

# Loop-invariant model formula (see the single-species cell for the terms).
# NOTE(review): patsy's `cr()` without `constraints="center"` may itself span the
# constant, which would be collinear with `0 + pair_type`; if so, only differences
# between pair_type coefficients (e.g. Transition - EEN) are identified — confirm.
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
# Patsy coefficient names, in the order stored as the EEN / PostEEN / Transition columns.
pair_type_coefs = ['pair_type[EEN]', 'pair_type[PostEEN]', 'pair_type[EEN:PostEEN]']

for species_id in motu_enrichment_results.index:
    inpath = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
    if not os.path.exists(inpath):
        print(f"No sfacts fit for {species_id}.")
        continue
    sf_fit = (
        sf.data.World.load(inpath)
        # Same "CF_<int>" sample renaming and CF_11/CF_15 swap as applied to motu_depth.
        .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
        .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    )

    comm = sf_fit.community.to_pandas()
    comm_bc_pdist = pd.DataFrame(
        squareform(pdist(comm, metric='braycurtis')),
        index=comm.index,
        columns=comm.index,
    )

    turnover_data = construct_turnover_analysis_data(
        comm_bc_pdist,
        meta=sample[lambda x: x.diet_or_media.isin(['EEN', 'PostEEN'])],
        sample_type_var="diet_or_media",
        stratum_var="subject_id",
        time_var="collection_date_relative_een_end",
    )
    all_pairs.append(turnover_data.assign(species_id=species_id))
    if turnover_data.empty:
        print(f"No data for {species_id}.")
        continue

    try:
        fit = smf.ols(formula, data=turnover_data).fit()
    except (ZeroDivisionError, ValueError):
        print(f"OLS failed for {species_id}.")
        continue

    # Pair types absent for this species (no such pairs observed) are recorded as NaN.
    coef_list = [fit.params.get(coef, np.nan) for coef in pair_type_coefs]
    lm_coef_results.append(
        (
            species_id,
            *coef_list,
            turnover_data.shape[0],
            # Mean dissimilarity within each subject, then averaged across subjects.
            turnover_data.groupby('stratum').diss.mean().mean(),
        )
    )

lm_coef_results = pd.DataFrame(
    lm_coef_results,
    columns=['motu_id', 'EEN', 'PostEEN', 'Transition', 'num_pairs', 'overall_mean_diss'],
)
# pd.concat raises ValueError on an empty list; fall back to an empty frame.
all_pairs = pd.concat(all_pairs) if all_pairs else pd.DataFrame()

In [None]:
# NOTE(review): `d` is not (re)assigned in this cell. When cells run in order, it
# still holds the last per-species strain-depth frame from the strain-depth loop.
# Confirm this display is intended.
d

In [None]:
# Combine enrichment statistics, overall mean abundance and turnover-model results
# per species. Keep species with more than 20 sample pairs and non-missing
# log2_ratio / Transition.
# NOTE(review): indicator_score uses the absolute Transition coefficient, which may
# not be identified if the spline term spans the intercept (see the turnover fits).
d = motu_enrichment_results.assign(overall_mean=motu_mean_rabund).dropna(subset=['overall_mean']).join(lm_coef_results.set_index('motu_id')).assign(
    # relative_transition=lambda x: x['Transition'] - x['EEN'],
    # signif_enrich=lambda w: w.pvalue < 0.05,
    indicator_score=lambda x: x.log2_ratio * x.Transition**2 * x.mean_PostEEN * x.mean_EEN,
)[lambda x: x.num_pairs > 20].dropna(subset=['log2_ratio', 'Transition'])

# Highlight species 101493 with a red ring (KeyError if it was filtered out above).
plt.scatter('log2_ratio', 'overall_mean_diss', data=d.loc[['101493']], edgecolor='r', facecolor='none', s=200)
# Enrichment vs. mean within-subject strain turnover, coloured by abundance (log scale).
plt.scatter('log2_ratio', 'overall_mean_diss', data=d, c='overall_mean', norm=mpl.colors.LogNorm())
# plt.ylim(-0.5, 1.5)
plt.colorbar(label='mean relative abundance')
plt.xlabel('log2_fold_change')
plt.ylabel('mean turn-over')
# plt.yscale('symlog', linthresh=1e-2)
# plt.xscale('log')

print(sp.stats.spearmanr(d.log2_ratio, d.overall_mean_diss))
# 50 species with the highest mean turnover, with taxonomy attached.
d.sort_values('overall_mean_diss', ascending=False).head(50).join(motu_taxonomy)

In [None]:
# Moderately PostEEN-enriched (0 < log2FC < 2.5), abundant (> 1%) species with low
# mean strain turnover (< 0.2), most stable first.
d[lambda x: (x.log2_ratio > 0) & (x.log2_ratio < 2.5) & (x.overall_mean_diss < 0.2) & (x.overall_mean > 1e-2)].sort_values('overall_mean_diss')

In [None]:
# Enrichment (log2 fold change) vs. the EEN:PostEEN pair_type coefficient,
# coloured by mean EEN abundance (log scale).
# NOTE: every argument of the `.assign()` below is commented out, so it is a no-op.
d = motu_enrichment_results.join(lm_coef_results.set_index('motu_id')).assign(
    # relative_transition=lambda x: x['Transition'] - x['EEN'],
    # signif_enrich=lambda w: w.pvalue < 0.05,
).dropna(subset=['log2_ratio', 'Transition'])

plt.scatter('log2_ratio', 'Transition', data=d, c='mean_EEN', norm=mpl.colors.LogNorm())
plt.ylim(-0.5, 1.5)
plt.colorbar()
# plt.yscale('symlog', linthresh=1e-2)
# plt.xscale('log')

sp.stats.spearmanr(d.log2_ratio, d.Transition)

In [None]:
# relative_transition: EEN:PostEEN coefficient minus the EEN-EEN coefficient, i.e. the
# extra dissimilarity of cross-phase pairs over within-EEN pairs in the turnover model.
# signif_enrich flags species with an enrichment p-value below 0.05 (uncorrected).
d = motu_enrichment_results.join(lm_coef_results.set_index('motu_id')).assign(
    relative_transition=lambda x: x['Transition'] - x['EEN'],
    signif_enrich=lambda w: w.pvalue < 0.05,
).dropna(subset=['relative_transition'])

plt.scatter('mean_EEN', 'relative_transition', data=d, c='signif_enrich')
plt.yscale('symlog', linthresh=1e-2)
plt.xscale('log')

In [None]:
# Species above 0.1% mean EEN abundance whose cross-phase turnover exceeds within-EEN turnover by > 0.1.
d[lambda x: (x.mean_EEN > 1e-3) & (x.relative_transition > 1e-1)]

In [None]:
# Top 20 species by relative_transition among those above 0.1% mean EEN abundance.
d[lambda x: (x.mean_EEN > 1e-3)].sort_values('relative_transition', ascending=False).head(20)

In [None]:
# Rank correlation between enrichment and relative strain turnover across species.
sp.stats.spearmanr(d['log2_ratio'], d['relative_transition'])

In [None]:
# Top 50 species by relative_transition, no abundance filter.
d.sort_values('relative_transition', ascending=False).head(50)

In [None]:
from matplotlib_venn import venn3


# Presence per (subject, sample_type): a taxon counts as present if any sample of
# that group exceeds detect_thresh relative abundance (zOTU, species and strain level).
detect_thresh = 1e-4
motu_detect_subject_sample_type = (motu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
sotu_detect_subject_sample_type = (sotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
rotu_detect_subject_sample_type = (rotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()


# NOTE: this overwrites the global subject_order defined in the metadata section.
subject_order = ['A', 'B', 'H']
# Only species-level presence is drawn in this cell.
taxon_presence_data = dict(species=motu_detect_subject_sample_type)

m_levels = len(taxon_presence_data)
n_subjects = len(subject_order)

fig, axs = plt.subplots(m_levels, n_subjects, figsize=(5 * n_subjects, 5 * m_levels), squeeze=False)

# Per subject: taxa detected in human, fermenter and mouse samples
# (species IDs in this cell, despite the variable name).
shared_strains = {}
for subject_id, ax_col in zip(subject_order, axs.T):
    print(f"Subject {subject_id}")
    print(f"--------")
    for (taxon_level, tax_detect_subject_sample_type), ax in zip(taxon_presence_data.items(), ax_col):
        print(taxon_level, "num detected", sep='\t')
        # Rows: taxa; columns: sample types present for this subject.
        d = tax_detect_subject_sample_type.loc[subject_id].T
        print(d['human'].sum(), f"across all human samples", sep='\t')
        print(d['Fermenter'].sum(), f"across all fermenter samples", sep='\t')
        print(d['mouse'].sum(), f"across all mouse samples", sep='\t')
        print((d['human'] & ~(d['Fermenter'] | d['mouse'])).sum(), f"in humans and NOT mouse or fermenter", sep='\t')
        print(((d['Fermenter'] | d['mouse']) & ~d['human']).sum(), f"in mouse or fermenter and NOT human", sep='\t')
        print((d['mouse'] & ~d['Fermenter']).sum(), f"in mouse and NOT fermenter", sep='\t')
        print((d['Fermenter'] & ~d['mouse']).sum(), f"in fermenter and NOT mouse", sep='\t')
        print()
        venn3([set(idxwhere(d['human'])), set(idxwhere(d['Fermenter'])), set(idxwhere(d['mouse']))], ax=ax, set_labels=('h', 'f', 'm'))
        ax.set_title(f"Subject {subject_id}")
        shared_strains[subject_id] = (set(idxwhere(d['human'])) & set(idxwhere(d['Fermenter'])) & set(idxwhere(d['mouse'])))
    print()


In [None]:
# Taxa detected in human, fermenter and mouse samples of all three subjects
# (species-level when run right after the species Venn cell).
shared_strains['A'] & shared_strains['B'] & shared_strains['H']

In [None]:
from matplotlib_venn import venn3


# Same presence tables as the species Venn cell (any sample above detect_thresh per subject x sample_type).
detect_thresh = 1e-4
motu_detect_subject_sample_type = (motu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
sotu_detect_subject_sample_type = (sotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
rotu_detect_subject_sample_type = (rotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()


subject_order = ['A', 'B', 'H']
# Only strain-level presence is drawn in this cell.
taxon_presence_data = dict(strains=sotu_detect_subject_sample_type)

m_levels = len(taxon_presence_data)
n_subjects = len(subject_order)

fig, axs = plt.subplots(m_levels, n_subjects, figsize=(5 * n_subjects, 5 * m_levels), squeeze=False)

# Per subject: strains detected in human, fermenter and mouse samples.
shared_strains = {}
for subject_id, ax_col in zip(subject_order, axs.T):
    print(f"Subject {subject_id}")
    print(f"--------")
    for (taxon_level, tax_detect_subject_sample_type), ax in zip(taxon_presence_data.items(), ax_col):
        print(taxon_level, "num detected", sep='\t')
        d = tax_detect_subject_sample_type.loc[subject_id].T
        print(d['human'].sum(), f"across all human samples", sep='\t')
        print(d['Fermenter'].sum(), f"across all fermenter samples", sep='\t')
        print(d['mouse'].sum(), f"across all mouse samples", sep='\t')
        print((d['human'] & ~(d['Fermenter'] | d['mouse'])).sum(), f"in humans and NOT mouse or fermenter", sep='\t')
        print(((d['Fermenter'] | d['mouse']) & ~d['human']).sum(), f"in mouse or fermenter and NOT human", sep='\t')
        print((d['mouse'] & ~d['Fermenter']).sum(), f"in mouse and NOT fermenter", sep='\t')
        print((d['Fermenter'] & ~d['mouse']).sum(), f"in fermenter and NOT mouse", sep='\t')
        print()
        venn3([set(idxwhere(d['human'])), set(idxwhere(d['Fermenter'])), set(idxwhere(d['mouse']))], ax=ax, set_labels=('h', 'f', 'm'))
        ax.set_title(f"Subject {subject_id}")
        shared_strains[subject_id] = (set(idxwhere(d['human'])) & set(idxwhere(d['Fermenter'])) & set(idxwhere(d['mouse'])))
    print()


In [None]:
# Per-subject strains detected in human, fermenter and mouse samples.
shared_strains

In [None]:
from matplotlib_venn import venn3


# NOTE: this cell uses a 10x higher detection threshold (1e-3) than the neighbouring Venn cells.
detect_thresh = 1e-3
motu_detect_subject_sample_type = (motu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
sotu_detect_subject_sample_type = (sotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
rotu_detect_subject_sample_type = (rotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()


subject_order = ['A', 'B', 'H']
taxon_presence_data = dict(zotus=rotu_detect_subject_sample_type, species=motu_detect_subject_sample_type, strains=sotu_detect_subject_sample_type)

m_levels = len(taxon_presence_data)
n_subjects = len(subject_order)

# Grid: one row per subject, one column per taxonomic resolution.
fig, axs = plt.subplots(n_subjects, m_levels, figsize=(5 * m_levels, 5 * n_subjects))

for subject_id, ax_row in zip(subject_order, axs):
    print(f"Subject {subject_id}")
    print(f"--------")
    for (taxon_level, tax_detect_subject_sample_type), ax in zip(taxon_presence_data.items(), ax_row):
        print(taxon_level, "num detected", sep='\t')
        d = tax_detect_subject_sample_type.loc[subject_id].T
        print(d['human'].sum(), f"across all human samples", sep='\t')
        print(d['Fermenter'].sum(), f"across all fermenter samples", sep='\t')
        print(d['mouse'].sum(), f"across all mouse samples", sep='\t')
        print((d['human'] & ~(d['Fermenter'] | d['mouse'])).sum(), f"in humans and NOT mouse or fermenter", sep='\t')
        print(((d['Fermenter'] | d['mouse']) & ~d['human']).sum(), f"in mouse or fermenter and NOT human", sep='\t')
        print((d['mouse'] & ~d['Fermenter']).sum(), f"in mouse and NOT fermenter", sep='\t')
        print((d['Fermenter'] & ~d['mouse']).sum(), f"in fermenter and NOT mouse", sep='\t')
        print()
        venn3([set(idxwhere(d['human'])), set(idxwhere(d['Fermenter'])), set(idxwhere(d['mouse']))], ax=ax, set_labels=('h', 'f', 'm'))
        ax.set_title((taxon_level, subject_id))
    print()


In [None]:
# Removed a stray fragment `speci`: it referenced an undefined name and raised
# NameError, which halted "Run All" at this cell.

In [None]:
from matplotlib_venn import venn3


# NOTE: despite the names, these tables are grouped by (sample_type, subject_id) in this
# cell, the reverse of the neighbouring cells, and they overwrite the same globals.
detect_thresh = 1e-4
motu_detect_subject_sample_type = (motu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['sample_type', 'subject_id']).any()
sotu_detect_subject_sample_type = (sotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['sample_type', 'subject_id']).any()
rotu_detect_subject_sample_type = (rotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['sample_type', 'subject_id']).any()


sample_type_order = ['human', 'Fermenter', 'mouse']
taxon_presence_data = dict(zotus=rotu_detect_subject_sample_type, species=motu_detect_subject_sample_type, strains=sotu_detect_subject_sample_type)

m_levels = len(taxon_presence_data)
n_types = len(sample_type_order)

# Grid: one row per sample type, one column per taxonomic resolution; each Venn
# compares taxa detected in subjects A, B and H.
fig, axs = plt.subplots(n_types, m_levels, figsize=(5 * m_levels, 5 * n_types))

for sample_type, ax_row in zip(sample_type_order, axs):
    for (taxon_level, tax_detect_subject_sample_type), ax in zip(taxon_presence_data.items(), ax_row):
        # Rows: taxa; columns: subjects with samples of this type.
        d = tax_detect_subject_sample_type.loc[sample_type].T
        venn3([set(idxwhere(d['A'])), set(idxwhere(d['B'])), set(idxwhere(d['H']))], ax=ax, set_labels=('A', 'B', 'H'))
        ax.set_title((taxon_level, sample_type))
    print()


In [None]:
# `species_taxonomy` is never defined in this notebook (NameError); the
# species-level taxonomy table built in the data-preparation section is motu_taxonomy.
motu_taxonomy

In [None]:
# NOTE(review): when cells run in order, `d1` here is the last pair_type group from
# the single-species (101378) turnover scatter loop, not the frame built in the
# next cell — confirm this display is intended.
d1

In [None]:
# Presence per (subject, sample_type): any sample of the group above detect_thresh.
detect_thresh = 1e-4
motu_detect_subject_sample_type = (motu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
sotu_detect_subject_sample_type = (sotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()
rotu_detect_subject_sample_type = (rotu_rabund > detect_thresh).join(sample[['subject_id', 'sample_type']]).groupby(['subject_id', 'sample_type']).any()


sample_type_order = ['human', 'Fermenter', 'mouse']
taxon_presence_data = dict(zotus=rotu_detect_subject_sample_type, species=motu_detect_subject_sample_type, strains=sotu_detect_subject_sample_type)
# taxon_presence_data = dict(species=motu_detect_subject_sample_type)


m_levels = len(taxon_presence_data)
n_types = len(sample_type_order)

# Grid: one row per sample type, one column per taxonomic resolution.
fig, axs = plt.subplots(n_types, m_levels, figsize=(5 * m_levels, 5 * n_types), squeeze=False)


# taxon_level = 'species'
# d0 = taxon_presence_data[taxon_level].loc[['A', 'B', 'H']]
# shared_species_list = idxwhere(d.rename_axis(columns='taxon').unstack('subject_id').T.human.unstack().all(1))
# d1 = d0[shared_species_list]

# fig, ax = plt.subplots()
for (taxon_level, tax_detect_subject_sample_type), ax_col in zip(taxon_presence_data.items(), axs.T):
    d0 = taxon_presence_data[taxon_level].loc[['A', 'B', 'H']]
    # Taxa detected in the human samples of all three subjects A, B and H.
    human_shared_taxa_list = idxwhere(d0.rename_axis(columns='taxon').unstack('subject_id').T.human.unstack().all(1))
    d1 = d0[human_shared_taxa_list]
    for sample_type, ax in zip(sample_type_order, ax_col):
        # subject_id -> set of those human-shared taxa detected in this sample type.
        d2 = d1.xs(sample_type, level='sample_type').apply(lambda x: set(idxwhere(x)), axis=1).to_dict()
        venn3([d2['A'], d2['B'], d2['H']], ax=ax, set_labels=['A', 'B', 'H'])
        ax.set_title((sample_type, taxon_level))

In [None]:
# Mean pairwise strain dissimilarity per species and pair type: averaged within
# each subject (stratum) first, then across subjects. Shows the 20 species with
# the highest cross-phase (EEN:PostEEN) turnover.
_diss_by_stratum = all_pairs.groupby(["species_id", "pair_type", "stratum"]).apply(
    lambda grp: grp.diss.mean()
)
(
    _diss_by_stratum.groupby(["species_id", "pair_type"])
    .mean()
    .unstack("pair_type")
    .sort_values(["EEN:PostEEN"], ascending=False)
    .head(20)
)