## Preamble

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Run the notebook from the project root so the relative paths used in later
# cells (e.g. "meta/...", "data/...") resolve.
# Raises KeyError if the PROJECT_ROOT environment variable is not set.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Last expression of the cell: shows the resolved working directory.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
sns.set_context("paper")
plt.rcParams["figure.dpi"] = 200

In [None]:
def _calculate_2tailed_pvalue_from_perm(obs, perms):
    """Return a two-tailed permutation p-value for an observed statistic.

    Parameters
    ----------
    obs
        The observed statistic (a scalar).
    perms
        Array-like supporting elementwise comparison with `obs` (e.g. a NumPy
        array or pandas Series) holding the statistic under each permutation.

    Returns
    -------
    float
        Twice the smaller of the two one-sided p-values, capped at 1.0.

    Each one-sided p-value uses the "+1" correction in both numerator and
    denominator, so the result is never exactly zero.
    """
    hypoth_left = perms > obs
    hypoth_right = perms < obs
    null_p_left = (hypoth_left.sum() + 1) / (len(hypoth_left) + 1)
    null_p_right = (hypoth_right.sum() + 1) / (len(hypoth_right) + 1)
    # Doubling the smaller tail can exceed 1 when `obs` sits near the middle
    # of the null distribution; a p-value must not be larger than 1.
    return np.minimum(1.0, np.minimum(null_p_left, null_p_right) * 2)

In [None]:
def linkage_order(linkage, labels):
    """Return `labels` reordered to follow the leaves of a linkage tree.

    `linkage` is a SciPy linkage matrix; `labels` must support indexing with a
    list of integer positions (e.g. a NumPy array or pandas Index). Leaves are
    visited in the tree's pre-order traversal.
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]


def is_prime(n):
    """Return True if `n` is prime, using trial division.

    Values less than or equal to 1 are never prime.
    """
    if n <= 1:
        return False
    # A composite n always has a divisor no larger than its square root.
    return all(n % divisor != 0 for divisor in range(2, int(n**0.5) + 1))


def iterate_primes_up_to(n, return_index=False):
    """Yield the primes strictly below ``ceil(n)`` in increasing order.

    Note that the bound is exclusive: ``ceil(n)`` itself is never yielded.

    Parameters
    ----------
    n
        Upper bound; rounded up to an integer before use.
    return_index
        If True, yield ``(index, prime)`` pairs where ``index`` counts the
        primes found so far (starting at 0) instead of bare primes.
    """
    upper = int(np.ceil(n))
    primes = (candidate for candidate in range(upper) if is_prime(candidate))
    for idx, prime in enumerate(primes):
        yield (idx, prime) if return_index else prime


def maximally_shuffled_order(sorted_order):
    """Return the items of `sorted_order` in a deterministic, interleaved order.

    Each item's original position is reduced modulo every prime returned by
    ``iterate_primes_up_to(sqrt(n))``; positions are then sorted by those
    residues, and the sorted positions are used as each item's new rank.
    (Presumably the goal is to move originally adjacent items apart, e.g. so
    neighbouring palette colors differ -- confirm against callers.)

    Parameters
    ----------
    sorted_order
        Sized sequence of labels, used as the index of an internal table.

    Returns
    -------
    list
        The same labels, reordered.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    for prime in primes_list:
        table[prime] = table.original_order % prime
    # Positions listed in residue-sorted order; assigned row-by-row as ranks.
    shuffled_positions = table.sort_values(primes_list).original_order.values
    table = table.assign(new_order=shuffled_positions)
    # The previous version also built an unused "delta" column whose
    # construction failed on empty input; it has been removed.
    return table.sort_values("new_order").index.to_list()

## Construct Metadata

In [None]:
# Colors keyed by sample-pair type.
pair_type_palette = {"Transition": "plum", "EEN": "pink", "PostEEN": "lightblue"}

# Colors keyed by diet / medium label.
diet_palette = dict(
    EEN="lightgreen",
    PostEEN="lightblue",
    InVitro="plum",
    PreEEN="lightpink",
)

# Display order of subjects; also fixes which palette slot each subject gets.
subject_order = [
    "A", "B", "H", "C", "D", "E", "F", "G", "K", "L",
    "M", "N", "O", "P", "Q", "R", "S", "T", "U",
]

# NOTE: Requires a dummy value because I want exactly 20 items.
subject_palette = lib.plot.construct_ordered_palette(
    [*subject_order, *(f"dummy{i}" for i in range(20 - len(subject_order)))],
    cm="tab20",
)
subject_palette["X"] = "black"

# Plot order plus matching marker / line-style lookups for the pair types.
pair_type_order = ["EEN", "Transition", "PostEEN"]
pair_type_marker_palette = dict(zip(pair_type_order, ["s", ">", "o"]))
pair_type_linestyle_palette = dict(zip(pair_type_order, [":", "-.", "-"]))

In [None]:
# Per-sample metadata, indexed by sample_id.
sample = pd.read_table("meta/een-mgen/sample.tsv")
# `label` is a (relative collection date, diet/media, sample_id) tuple per row;
# it is built before sample_id is moved into the index.
sample["label"] = sample[
    ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
].apply(tuple, axis=1)
sample = sample.set_index("sample_id")

# Per-subject metadata, indexed by subject_id.
subject = pd.read_table("meta/een-mgen/subject.tsv", index_col="subject_id")

In [None]:
# zOTU count table as read: rows are zOTUs, columns are samples plus a
# trailing "taxonomy" annotation column.
rotu_counts = pd.read_table("data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID")
rotu_counts = rotu_counts.rename_axis(index="zotu", columns="sample_id")

# Split off the taxonomy strings, then flip to samples-by-zOTUs.
rotu_taxonomy = rotu_counts["taxonomy"]
rotu_counts = rotu_counts.drop(columns=["taxonomy"]).T

# Row-normalize counts to per-sample relative abundances.
rotu_rabund = rotu_counts.div(rotu_counts.sum(axis=1), axis=0)

# Average-linkage clustering of samples on Bray-Curtis dissimilarity.
sample_rotu_bc_linkage = linkage(
    rotu_rabund, method="average", metric="braycurtis", optimal_ordering=True
)

In [None]:
missing_samples = sorted(idxwhere(~rotu_counts.index.to_series().isin(sample.index)))
print(len(missing_samples), ", ".join(missing_samples))

In [None]:
x = rotu_rabund
row_linkage = sample_rotu_bc_linkage

# Row annotation strips: subject color, and black for the two samples
# (CF_11, CF_15) flagged as a swap, grey otherwise.
row_colors = pd.DataFrame(
    {
        "subj": sample.subject_id.map(subject_palette),
        "swap": sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    }
)

# Heatmap of zOTU relative abundance with samples ordered by the
# precomputed Bray-Curtis linkage; the power norm compresses large values.
sns.clustermap(
    x,
    norm=mpl.colors.PowerNorm(1 / 5),
    row_colors=row_colors,
    row_linkage=row_linkage,
)

In [None]:
pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv")

In [None]:
# Long-format GT-Pro species depths -> samples-by-species table (missing = 0).
gtpro_depth = pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv",
    index_col=["sample", "species_id"],
)["depth"].unstack(fill_value=0)
# Species ids become strings; sample labels are rebuilt as "CF_<int>" from the
# second underscore-separated field (int() drops any zero padding).
gtpro_depth = gtpro_depth.rename(
    columns=str, index=lambda s: f"CF_{int(s.split('_')[1])}"
)
# Sample swap: exchange the labels of CF_11 and CF_15.
gtpro_depth = gtpro_depth.rename(index={"CF_15": "CF_11", "CF_11": "CF_15"})

# Per-sample relative abundance.
gtpro_rabund = gtpro_depth.div(gtpro_depth.sum(axis=1), axis=0)

gtpro_rabund

In [None]:
# Headerless long-format species depths -> samples-by-species table (missing = 0).
motu_depth = pd.read_table(
    "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
    names=["sample", "species_id", "depth"],
    index_col=["sample", "species_id"],
)["depth"].unstack(fill_value=0)
# Species ids become strings; sample labels are rebuilt as "CF_<int>" from the
# second underscore-separated field (int() drops any zero padding).
motu_depth = motu_depth.rename(
    columns=str, index=lambda s: f"CF_{int(s.split('_')[1])}"
)
# Sample swap: exchange the labels of CF_11 and CF_15.
motu_depth = motu_depth.rename(index={"CF_15": "CF_11", "CF_11": "CF_15"})

# Per-sample relative abundance.
motu_rabund = motu_depth.div(motu_depth.sum(axis=1), axis=0)

motu_rabund

In [None]:
# Compare sample clustering from metagenomic species (x) and zOTUs (y).
# NOTE(review): align_indexes is a project helper; presumably it restricts both
# tables to their shared sample index -- confirm in lib.pandas_util.
x, y = align_indexes(motu_rabund, rotu_rabund)


# Independent average-linkage Bray-Curtis clusterings of the same samples.
x_linkage = linkage(x, method="average", metric="braycurtis", optimal_ordering=True)
y_linkage = linkage(y, method="average", metric="braycurtis", optimal_ordering=True)
# Annotation strips: subject color, and black for CF_11 / CF_15 (grey otherwise).
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    )
)

# Sample-by-sample Bray-Curtis distance matrix computed from x only.
x_pdist = pd.DataFrame(
    squareform(pdist(x, metric="braycurtis")), index=x.index, columns=x.index
)
# Rows are ordered by y's linkage and columns by x's linkage, so the two axes
# of the heatmap follow different dendrograms even though the values come from x.
sns.clustermap(
    x_pdist,
    row_linkage=y_linkage,
    col_linkage=x_linkage,
    row_colors=colors,
    col_colors=colors,
)

In [None]:
x, y = align_indexes(motu_rabund, gtpro_rabund)


x_linkage = linkage(x, method="average", metric="braycurtis", optimal_ordering=True)
y_linkage = linkage(y, method="average", metric="braycurtis", optimal_ordering=True)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    )
)

x_pdist = pd.DataFrame(
    squareform(pdist(x, metric="braycurtis")), index=x.index, columns=x.index
)
sns.clustermap(
    x_pdist,
    row_linkage=y_linkage,
    col_linkage=x_linkage,
    row_colors=colors,
    col_colors=colors,
)

In [None]:
# Log-spaced bins shared by both panels, matching the log-scaled x axis.
# (Previously a linear `bins` was defined here but never used; the histogram
# call hard-coded this logspace instead.)
bins = np.logspace(-1, 5, num=100)

fig, axs = plt.subplots(2, sharex=True)

# One panel per marginal of the depth table: totals per sample (row sums)
# and totals per species (column sums).
for (title, x), ax in zip(
    dict(
        total_depth_by_sample=motu_depth.sum(1),
        total_depth_by_species=motu_depth.sum(0),
    ).items(),
    axs.flatten(),
):
    ax.hist(x, bins=bins)
    ax.set_title(title)
    ax.set_xscale("log")
fig.tight_layout()

In [None]:
motu_rabund.mean().sort_values(ascending=False).head(20)

In [None]:
# Stacked histograms of relative abundance for the species detected (> 1e-5)
# in the largest number of samples.
n_species = 10
top_motus = (
    (motu_rabund > 1e-5).sum().sort_values(ascending=False).head(n_species).index
)

# One thin panel per species, sharing both axes.
fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_motus, axs):
    # ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(motu_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    # Fraction of samples in which the species exceeds the same 1e-5 cutoff.
    prevalence = (motu_rabund[species_id] > 1e-5).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    # Strip all axis decoration; the annotation below labels the panel.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    # Dotted reference line at the prevalence cutoff.
    ax.axvline(1e-5, lw=1, linestyle=":", color="k")

# `ax` is still the last panel drawn by the loop: restore its x axis only.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a Series keyed by rank prefix.

    The string must contain exactly seven fields (domain through species);
    otherwise the Series constructor raises ValueError. Field values are kept
    verbatim, including any rank prefix they carry.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
motu_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

_motu_taxonomy = pd.read_table(motu_taxonomy_inpath)
# Keep only the rows where the genome is its own species representative.
_motu_taxonomy = _motu_taxonomy[_motu_taxonomy.Genome == _motu_taxonomy.Species_rep]
# Species id is "1" followed by the third "-"-separated field of the accession.
_motu_taxonomy = _motu_taxonomy.assign(
    species_id="1" + _motu_taxonomy.MGnify_accession.str.split("-").str[2]
).set_index("species_id")

# One column per taxonomic rank, parsed from the lineage string.
motu_taxonomy = _motu_taxonomy.Lineage.apply(parse_taxonomy_string)
motu_taxonomy

In [None]:
for _species_id in top_motus.astype(str):
    print(_species_id, ":", ";".join(motu_taxonomy.loc[_species_id].values))

In [None]:
motu_taxonomy[lambda x: x.s__.str.endswith("hansenii")]

In [None]:
for _species_id in ["102544", "102506", "101303", "100150", "102330", "101704"]:
    print(
        _species_id,
        (motu_rabund[_species_id] > 0.0001).mean().round(2),
        (motu_rabund[_species_id] > 0.001).mean().round(2),
        motu_taxonomy.loc[_species_id].s__,
        sep="\t\t",
    )

In [None]:
for _species_id in ["100323", "101396", "101493", "102351"]:
    print(
        _species_id,
        (motu_rabund[_species_id] > 0.0001).mean().round(2),
        (motu_rabund[_species_id] > 0.001).mean().round(2),
        motu_taxonomy.loc[_species_id].s__,
        sep="\t\t",
    )

In [None]:
# Stacked histograms of relative abundance for the species that exceed the
# prevalence threshold in the largest number of samples.
n_species = 20
# Single source of truth for the cutoff: it ranks the species, sets the
# annotated prevalence, and positions the reference line. (The reference line
# was previously left at 1e-5 while the other two uses had moved to 1e-3.)
prevalence_threshold = 1e-3
top_motus = (
    (motu_rabund > prevalence_threshold)
    .sum()
    .sort_values(ascending=False)
    .head(n_species)
    .index
)

# One thin panel per species, sharing both axes.
fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_motus, axs):
    ax.hist(motu_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    # Fraction of samples in which the species exceeds the threshold.
    prevalence = (motu_rabund[species_id] > prevalence_threshold).mean()
    ax.set_title("")
    # Strip all axis decoration; the annotation below labels the panel.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    # Dotted reference line at the prevalence threshold.
    ax.axvline(prevalence_threshold, lw=1, linestyle=":", color="k")

# `ax` is still the last panel drawn by the loop: restore its x axis only.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Build a samples-by-strain depth table: each species' depth is split across
# its inferred strains according to the per-sample strain fractions.
sotu_depth = []
missing_files = []
for species_id in motu_depth.columns:
    path = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        # Long-format (sample, strain) table -> samples-by-strains, with the
        # same sample relabelling and CF_11/CF_15 swap used for the depth tables.
        d = (
            pd.read_table(path, index_col=["sample", "strain"])
            .squeeze()
            .unstack()
            .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            .rename({'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap.
        )
    except FileNotFoundError:
        # No strain fit for this species: fall back to an empty table, so the
        # whole species depth ends up in the "other" (-1) column below.
        missing_files.append(path)
        d = pd.DataFrame([])
    # Keep strains whose fraction summed over samples exceeds 0.05.
    _keep_strains = idxwhere(d.sum() > 0.05)
    # Every sample with a strain fit must also have species depths.
    assert d.index.isin(motu_depth.index).all()
    # Align to the full sample list; samples without a fit get zeros.
    d = d.reindex(index=motu_depth.index, columns=_keep_strains, fill_value=0)
    # Remainder not explained by kept strains goes to a catch-all column named -1.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    # Clip negative remainders, then renormalize rows to sum to 1.
    d[d < 0] = 0
    d = d.divide(d.sum(1), axis=0)
    # Scale the fractions by the species depth in each sample.
    d = d.multiply(motu_depth[species_id], axis=0)
    # Column names become "<species_id>_<strain>" ("<species_id>_-1" for other).
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    sotu_depth.append(d)
sotu_depth = pd.concat(sotu_depth, axis=1)
sotu_rabund = sotu_depth.divide(sotu_depth.sum(1), axis=0)
# Cell output: (number of species, number of species lacking a strain-fit file).
len(motu_depth.columns), len(missing_files)

In [None]:
plt.hist(sotu_depth.values.flatten() + 1e-10, bins=np.logspace(-10, 0))
plt.yscale("log")
plt.xscale("log")

In [None]:
for _species_id in top_motus.astype(str):
    print(_species_id, ":", ";".join(motu_taxonomy.loc[_species_id].values))

In [None]:
x, y = align_indexes(sotu_rabund, motu_rabund)

x_linkage = linkage(x, method="average", metric="braycurtis", optimal_ordering=True)
y_linkage = linkage(y, method="average", metric="braycurtis", optimal_ordering=True)

colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15", "CF_431", "CF_427"])
        .replace({False: "grey", True: "black"}),
    )
)

x_pdist = pd.DataFrame(
    squareform(pdist(x, metric="braycurtis")), index=x.index, columns=x.index
)
sns.clustermap(
    x_pdist,
    row_linkage=y_linkage,
    col_linkage=x_linkage,
    row_colors=colors,
    col_colors=colors,
)

In [None]:
sample[["subject_id", "sample_type"]].value_counts().unstack(fill_value=0).sort_values(
    "human", ascending=False
).head(20)

In [None]:
sample[["subject_id", "diet_or_media"]].value_counts().unstack(
    fill_value=0
).sort_values("EEN", ascending=False).head(20)

## Understand Fermenter Experiment

In [None]:
sample[lambda x: x.subject_id.isin(["A", "B", "H"]) & x.index.isin(motu_rabund.index)][
    ["subject_id", "source_samples", "sample_type", "diet_or_media", "status_mouse_inflamed"]
].value_counts(dropna=False).sort_index()

In [None]:
def _label_experiment_sample(x):
    """Build a human-readable label for one row of the sample table.

    Parameters
    ----------
    x
        A row (e.g. from ``DataFrame.apply(..., axis=1)``) exposing ``name``
        plus the metadata attributes read below; which attributes are read
        depends on ``x.sample_type``.

    Returns
    -------
    str
        ``"[<name>] ..."`` with type-specific details.

    Raises
    ------
    ValueError
        If ``sample_type`` is not one of the handled types, or if a mouse
        sample has an unrecognized ``status_mouse_inflamed`` value.
    """
    if x.sample_type == "human":
        return f"[{x.name}] {x.collection_date_relative_een_end} {x.diet_or_media}"
    if x.sample_type == "Fermenter_inoculum":
        return f"[{x.name}] {x.source_samples} inoc {x.diet_or_media}"
    if x.sample_type == "Fermenter":
        return f"[{x.name}] {x.source_samples} frmnt {x.diet_or_media}"
    if x.sample_type == "mouse":
        if x.status_mouse_inflamed == "Inflamed":
            status = "inflam"
        elif x.status_mouse_inflamed == "not_Inflamed":
            status = "not_inf"
        else:
            # Name the field that is actually invalid (the message previously
            # blamed the sample type).
            raise ValueError(
                f"status_mouse_inflamed {x.status_mouse_inflamed} not understood"
            )
        return f"[{x.name}] {x.source_samples} 🐭 {x.mouse_genotype} {x.diet_or_media} {status}"
    raise ValueError(f"sample type {x.sample_type} not understood")


sample.sort_values(
    [
        "subject_id",
        "collection_date_relative_een_end",
        "source_samples",
        "sample_type",
        "diet_or_media",
    ]
).assign(label=lambda d: d.apply(_label_experiment_sample, axis=1)).label

In [None]:
sample[
    lambda x: x.subject_id.isin(["A"]) & x.index.isin(motu_rabund.index)
].sort_values(
    [
        "subject_id",
        "collection_date_relative_een_end",
        "sample_type",
        "source_samples",
        "diet_or_media",
        "mouse_genotype",
        "status_mouse_inflamed",
    ]
).assign(
    label=lambda d: d.apply(_label_experiment_sample, axis=1)
)

In [None]:
sample[
    lambda x: x.subject_id.isin(["B"]) & x.index.isin(motu_rabund.index)
].sort_values(
    [
        "subject_id",
        "collection_date_relative_een_end",
        "source_samples",
        "sample_type",
        "diet_or_media",
        "mouse_genotype",
        "status_mouse_inflamed",
    ]
).assign(
    label=lambda d: d.apply(_label_experiment_sample, axis=1)
)

In [None]:
sample[
    lambda x: x.subject_id.isin(["H"]) & x.index.isin(motu_rabund.index)
].sort_values(
    [
        "subject_id",
        "collection_date_relative_een_end",
        "source_samples",
        "sample_type",
        "diet_or_media",
        "mouse_genotype",
        "status_mouse_inflamed",
    ]
).assign(
    label=lambda d: d.apply(_label_experiment_sample, axis=1)
)

## Confirm Sample Swap

In [None]:
d = sotu_rabund.loc[:, lambda x: (x > 1e-5).sum() > 1]
sample_linkage_strain_rabund = linkage(
    d,
    method="average",
    metric="braycurtis",
    optimal_ordering=True,
)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        # type=sample.diet_or_media.map(diet_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15", "CF_431", "CF_427", "CF_402"])
        .map({True: "black", False: "grey"}),
    )
)

cg = sns.clustermap(
    d.T,
    norm=mpl.colors.PowerNorm(1 / 5),
    col_linkage=sample_linkage_strain_rabund,
    metric="cosine",
    xticklabels=1,
    figsize=(18, 10),
    col_colors=colors,
    dendrogram_ratio=(0.05, 0.05),
    yticklabels=0,
)
cg.ax_cbar.set_visible(False)

In [None]:
d = motu_rabund.loc[:, lambda x: (x > 1e-5).sum() > 1]
sample_linkage_strain_rabund = linkage(
    d,
    method="average",
    metric="braycurtis",
    optimal_ordering=True,
)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        # type=sample.diet_or_media.map(diet_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15", "CF_431", "CF_427", "CF_402"])
        .map({True: "black", False: "grey"}),
    )
)

cg = sns.clustermap(
    d.T,
    norm=mpl.colors.PowerNorm(1 / 5),
    col_linkage=sample_linkage_strain_rabund,
    metric="cosine",
    xticklabels=1,
    figsize=(15, 10),
    col_colors=colors,
    dendrogram_ratio=(0.05, 0.05),
    yticklabels=0,
)
cg.ax_cbar.set_visible(False)

In [None]:
d = rotu_rabund.loc[:, lambda x: (x > 1e-5).sum() > 1]
sample_linkage_strain_rabund = linkage(
    d,
    method="average",
    metric="braycurtis",
    optimal_ordering=True,
)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        # type=sample.diet_or_media.map(diet_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15", "CF_431", "CF_427", "CF_402"])
        .map({True: "black", False: "grey"}),
    )
)

cg = sns.clustermap(
    d.T,
    norm=mpl.colors.PowerNorm(1 / 5),
    col_linkage=sample_linkage_strain_rabund,
    metric="cosine",
    xticklabels=1,
    figsize=(85, 10),
    col_colors=colors,
    dendrogram_ratio=(0.05, 0.05),
    yticklabels=0,
)
cg.ax_cbar.set_visible(False)

# fig = plt.gcf()
# fig.savefig('fig/een_zotus_clustermap.png')

In [None]:
# Stand-alone legend for the subject palette: empty scatters draw nothing but
# register one (color, label) handle each.
for k in subject_palette:
    plt.scatter([], [], color=subject_palette[k], label=k)
# NOTE(review): no "other" key is assigned to subject_palette in the visible
# cells (only "X" is added explicitly). This lookup only works if
# lib.plot.construct_ordered_palette supplies it or returns a mapping with a
# default -- confirm, otherwise this raises KeyError.
k = "other"
plt.scatter([], [], color=subject_palette[k], label=k)
plt.legend(ncols=5)

In [None]:
suspect_labels = ["CF_11", "CF_15", "CF_431", "CF_427", "CF_402"]
focal_subjects = list(sample.loc[suspect_labels].subject_id.unique())
focal_samples = idxwhere(
    sample.subject_id.isin(focal_subjects)
    & sample.index.to_series().isin(sotu_rabund.index)
)

d = sotu_rabund.loc[focal_samples].loc[:, lambda x: (x > 1e-5).sum() > 1]
sample_linkage_strain_rabund = linkage(
    d,
    method="average",
    metric="braycurtis",
    optimal_ordering=True,
)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        # type=sample.diet_or_media.map(diet_palette),
        swap=sample.index.to_series()
        .isin(suspect_labels)
        .map({True: "black", False: "grey"}),
    )
)

cg = sns.clustermap(
    d.T,
    norm=mpl.colors.PowerNorm(1 / 5),
    col_linkage=sample_linkage_strain_rabund,
    metric="cosine",
    xticklabels=1,
    figsize=(15, 10),
    col_colors=colors,
    dendrogram_ratio=(0.05, 0.05),
    yticklabels=0,
)
cg.ax_cbar.set_visible(False)

In [None]:
suspect_labels = ["CF_11", "CF_15", "CF_431", "CF_427", "CF_402"]
focal_subjects = list(sample.loc[suspect_labels].subject_id.unique())
focal_samples = idxwhere(
    sample.subject_id.isin(focal_subjects)
    & sample.index.to_series().isin(rotu_rabund.index)
)

d = rotu_rabund.loc[focal_samples].loc[:, lambda x: (x > 1e-5).sum() > 1]
sample_linkage_strain_rabund = linkage(
    d,
    method="average",
    metric="braycurtis",
    optimal_ordering=True,
)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        # type=sample.diet_or_media.map(diet_palette),
        swap=sample.index.to_series()
        .isin(suspect_labels)
        .map({True: "black", False: "grey"}),
    )
)

cg = sns.clustermap(
    d.T,
    norm=mpl.colors.PowerNorm(1 / 5),
    col_linkage=sample_linkage_strain_rabund,
    metric="cosine",
    xticklabels=1,
    figsize=(15, 10),
    col_colors=colors,
    dendrogram_ratio=(0.05, 0.05),
    yticklabels=0,
)
cg.ax_cbar.set_visible(False)

In [None]:
sample.loc[suspect_labels]

In [None]:
all_new_samples_list = [
    "CF_379",
    "CF_380",
    "CF_381",
    "CF_384",
    "CF_385",
    "CF_386",
    "CF_426",
    "CF_427",
    "CF_428",
    "CF_429",
    "CF_430",
    "CF_431",
    "CF_395",
    "CF_397",
    "CF_402",
    "CF_406",
    "CF_408",
    "CF_409",
    "CF_140",
    "CF_141",
    "CF_142",
    "CF_149",
    "CF_150",
    "CF_151",
    "CF_170",
    "CF_171",
    "CF_172",
    "CF_173",
    "CF_174",
    "CF_175",
    "CF_152",
    "CF_153",
    "CF_154",
    "CF_155",
    "CF_156",
    "CF_157",
    "CF_115",
    "CF_116",
    "CF_117",
    "CF_118",
    "CF_119",
    "CF_120",
    "CF_127",
    "CF_128",
    "CF_130",
    "CF_131",
    "CF_132",
    "CF_133",
    "CF_667",
    "CF_668",
    "CF_669",
    "CF_670",
    "CF_671",
    "CF_672",
]

sample.loc[all_new_samples_list]

## EEN vs. PostEEN Species Enrichment Analysis

In [None]:
# Mean relative abundance per (subject, diet), restricted to EEN / PostEEN samples.
d0 = (
    motu_rabund.join(sample[["subject_id", "diet_or_media"]])
    .loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])]
    .groupby(["subject_id", "diet_or_media"])
    .mean()
)

species_enrichment_stats = []
for species in d0.columns:
    # Subjects-by-diet table; keep only subjects observed under both diets.
    d1 = d0[species].unstack("diet_or_media").dropna()
    prevalence = (d1 > 0).mean()
    # Pseudocount keeps the log ratio finite when an abundance is zero.
    d1_pc = d1 + 0.00001
    log2_fold_change = np.log2(d1_pc.PostEEN / d1_pc.EEN)
    if (log2_fold_change == 0).all():
        # No subject differs between diets: skip the paired test.
        pval = 1
    else:
        # Paired Wilcoxon signed-rank test on the raw (non-pseudocounted) values.
        pval = sp.stats.wilcoxon(d1.PostEEN, d1.EEN)[1]

    species_enrichment_stats.append(
        (
            species,
            prevalence.PostEEN,
            prevalence.EEN,
            log2_fold_change.mean(),
            pval,
        )
    )

species_enrichment_stats = pd.DataFrame(
    species_enrichment_stats,
    columns=[
        "species_id",
        "prevalence_post",
        "prevalence_een",
        "mean_log2_fold_change",
        "pvalue",
    ],
).set_index("species_id")

In [None]:
species_enrichment_stats.loc['102506']

In [None]:
# FDR-correct only the species present in more than 10% of subjects under at
# least one of the two diets, then order by raw p-value.
species_enrichment_stats_with_fdr = (
    species_enrichment_stats.loc[
        lambda x: (x.prevalence_post > 0.1) | (x.prevalence_een > 0.1)
    ]
    .assign(fdr=lambda x: fdrcorrection(x.pvalue)[1])
    .sort_values("pvalue")
)
# Cell output: hits with raw p < 0.01, annotated with species-level taxonomy.
species_enrichment_stats_with_fdr.loc[lambda x: x.pvalue < 0.01].join(
    motu_taxonomy.s__
)

In [None]:
species_enrichment_stats_with_fdr.sort_values('pvalue').head(20)

## Statistical Linkage Between zOTUs and GT-Pro Species

### "Enterobacteriaceae"

In [None]:
focal_family_list = ["Enterobacteriaceae"]
# rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]

In [None]:
# zOTUs whose taxonomy string contains the family as a complete ';'-delimited field.
# NOTE: str.contains treats the pattern as a regular expression by default, so
# this assumes the family name has no regex metacharacters.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};"))
# Species whose re-joined lineage contains "f__<family>;", limited to species
# that actually have a column in motu_rabund.
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
# NOTE: PLSRegression is imported but not used in this cell.
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

# Relative abundances of the family's zOTUs (x) and species (y).
# NOTE(review): align_indexes is a project helper; presumably it restricts both
# tables to their shared samples -- confirm in lib.pandas_util.
x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
# Use the smallest of the four table dimensions as the component count, and
# fit on unscaled abundances (scale=False).
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Species-by-zOTU association matrix: product of the fitted y and x rotation
# matrices, summed over PLS components (rows = species, columns = zOTUs).
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
# Cell output: (number of species, number of zOTUs).
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# zOTU-by-species Spearman correlation across samples. cdist expects a
# distance, so the metric returns 1 - rho and the outer "1 -" converts back.
# The lambda's x, y parameters shadow the outer tables only inside the lambda.
# NOTE(review): the metric is a Python callable evaluated for every
# (zOTU, species) pair, which is likely slow for large families.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# Long table with one row per (species, zOTU) pair. cross_correlation is
# transposed so both sources share the species-by-zOTU orientation.
# *_rank_A ranks species within each zOTU (axis=0, down each column);
# *_rank_B ranks zOTUs within each species (axis=1, along each row).
# ascending=False means rank 1 is the largest score.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

#### Zotu4

In [None]:
focal_zotu = ["Zotu4"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### 102351

In [None]:
focal_motu = '102351'

(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy.str[-20:])
    .loc[focal_motu]
)

### "Eggerthellaceae" (Zotu172)

In [None]:
focal_zotu = ["Zotu172"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Lachnospiraceae"

#### Zotu114

In [None]:
focal_zotu = ["Zotu114"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

# Fit a canonical PLS between the two family-restricted abundance tables.
# `align_indexes` presumably restricts both tables to their shared samples —
# TODO confirm against lib.pandas_util.
x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
# Use as many components as the smallest dimension of either table allows;
# `scale=False` leaves the abundances unscaled.
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Combine the fitted rotations of both tables into a single
# (y taxa x x taxa) weight matrix: rows are y.columns, columns are x.columns.
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Spearman correlation between every x taxon and every y taxon across samples.
# `cdist` needs a distance, so the metric is 1 - rho and the outer `1 - ...`
# converts the result back to a correlation. Rows are x.columns, columns are
# y.columns (i.e. transposed relative to `pseudo_map`).
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# Long-format table with one row per (y taxon, x taxon) pair.
# `cross_correlation` is transposed so it has the same orientation as
# `pseudo_map` before stacking.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        # *_rank_A: rank within each column (one x taxon, across all y taxa).
        # *_rank_B: rank within each row (one y taxon, across all x taxa).
        # `ascending=False` makes rank 1 the largest value.
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
# Top 10 candidate partners for the focal zOTU.
(
    # Put the zotu level first so the focal zOTU can be selected with `.loc`.
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    # Attach the mean relative abundance of each taxon in both tables.
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    # Weight the PLS value by the mean abundance of both partners.
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    # Add the species-level taxonomy label for readability.
    .join(motu_taxonomy.s__)
)

#### Zotu25

In [None]:
focal_zotu = ["Zotu25"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu5

In [None]:
focal_zotu = ["Zotu5"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu10

In [None]:
focal_zotu = ["Zotu10"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu160

In [None]:
focal_zotu = ["Zotu160"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu31

In [None]:
focal_zotu = ["Zotu31"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu155

In [None]:
focal_zotu = ["Zotu155"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu75

In [None]:
focal_zotu = ["Zotu75"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu13

In [None]:
focal_zotu = ["Zotu13"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu155 (repeated)

In [None]:
focal_zotu = ["Zotu155"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu27

In [None]:
focal_zotu = ["Zotu27"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu104

In [None]:
focal_zotu = ["Zotu104"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

##### Alternative zOTUs

In [None]:
# Reverse look-up: start from one species ('100205') and list the zOTUs paired
# with it.
(
    # `.loc[[...]]` on the un-reordered `d` selects on its first index level —
    # presumably species_id, since `pseudo_map` is indexed by y.columns; verify.
    d
    .loc[['100205']]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    # Ranked by the raw PLS value here, not by the abundance-weighted pls_score.
    .sort_values("pls", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu167

In [None]:
focal_zotu = ["Zotu167"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("cc", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu26

In [None]:
focal_zotu = ["Zotu26"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

In [None]:
# Look up the family label of Zotu100 (third-from-last ';'-separated field).
# NOTE(review): this cell overwrites `focal_zotu` and `focal_family`, but the
# next section reassigns both before they are used, and the same look-up is
# repeated in the "Peptostreptococcaceae" (Zotu100) section — looks like a
# leftover; confirm before relying on it.
focal_zotu = ["Zotu100"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
# All focal zOTUs must resolve to a single family label.
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

### "Erysipelatoclostridiaceae" (Zotu46)

In [None]:
focal_zotu = ["Zotu46"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Oscillospiraceae" (Zotu49)

In [None]:
focal_zotu = ["Zotu49"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Erysipelotrichaceae" (Zotu34)

In [None]:
focal_zotu = ["Zotu34"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Ruminococcaceae"

#### Zotu78

In [None]:
focal_zotu = ["Zotu78"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu16

In [None]:
focal_zotu = ["Zotu16"]
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

In [None]:
motu_taxonomy.g__.loc['102040']

#### Zotu9

In [None]:
focal_zotu = ["Zotu9"]
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

In [None]:
motu_taxonomy.loc['103166']

### "Oscillospiraceae" (Zotu13)

In [None]:
focal_zotu = ["Zotu13"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Bacteroidaceae"

#### Zotu74

In [None]:
focal_zotu = ["Zotu74"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

##### Deeper dive into alternative zOTUs for the same species

In [None]:
# Reverse look-up: start from one species ('102549') and list the zOTUs paired
# with it.
(
    # `.loc[[...]]` on the un-reordered `d` selects on its first index level —
    # presumably species_id, since `pseudo_map` is indexed by y.columns; verify.
    d
    .loc[['102549']]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    # Ranked by the raw PLS value here, not by the abundance-weighted pls_score.
    .sort_values("pls", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu6

In [None]:
focal_zotu = ["Zotu6"]
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

In [None]:
motu_taxonomy[['g__', 's__']].loc['101346']

#### Zotu12

In [None]:
focal_zotu = ["Zotu12"]
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

In [None]:
motu_taxonomy[['g__', 's__']].loc['101337']

#### 102478

In [None]:
focal_motu = '102478'

print(motu_taxonomy.loc[focal_motu])

(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy.str[-20:])
    .loc[focal_motu]
)

#### 101378

In [None]:
focal_motu = '101378'

print(motu_taxonomy.loc[focal_motu])

(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy.str[-20:])
    .loc[focal_motu]
)

In [None]:
motu_taxonomy.loc[focal_motu]

### "Peptostreptococcaceae" (Zotu100)

In [None]:
focal_zotu = ["Zotu100"]
focal_family_list = rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Enterococcaceae" (100323)

In [None]:
# This section starts from a species ID rather than a zOTU, so the family is
# read from the `f__` taxonomy column; `.str[3:]` strips the 3-character
# "f__" prefix so the label can be reused in the family-matching patterns.
focal_motu = ["100323"]
focal_family_list = list(motu_taxonomy.loc[focal_motu].f__.str[3:].values)
# Exactly one family label is expected for the focal species.
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# Pick the taxa in each table that belong to the focal family.
# `regex=False` matches the family label literally: taxonomy labels can contain
# regex metacharacters (brackets, dots, parentheses) that would otherwise be
# interpreted as a pattern and silently select the wrong taxa.
x_taxa = idxwhere(rotu_taxonomy.str.contains(f";{focal_family};", regex=False))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};", regex=False
    )
    # Keep only entries that have a column in the relative-abundance table.
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.spearmanr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy.str[-20:])
    .loc[focal_motu]
)

## Focal Species Plots

In [None]:
def strains_in_subjects(
    _species_rabund,
    _world,
    _meta,
    _species_rabundB=None,
    savefig=False,
    plt_mean_key=False,
    subject_order=subject_order,
    ncols=2,
    ax_width=7,
    ax_height=4,
):
    """Plot strain composition per subject with species abundance overlaid.

    One panel is drawn per subject. Stacked bars show the strain fractions of
    each of that subject's samples (left axis); a line shows the species
    relative abundance (right axis).

    Parameters
    ----------
    _species_rabund
        Species relative abundance. It is assigned as a column of `_meta`, so
        it is expected to align with `_meta`'s index.
    _world
        Strain-deconvolution result; must provide
        `drop_low_abundance_strains`, `community`, `genotype` and `strain`.
    _meta
        Sample metadata table. Must contain `subject_id`, `label` and the
        columns used as sort keys below.
    _species_rabundB
        Optional second abundance series, drawn as a dotted line on the
        right axis.
    savefig
        Falsy to skip saving; otherwise passed to `fig.savefig` as the
        destination of the per-subject figure.
    plt_mean_key
        If true, additionally draw one stacked bar per subject showing the
        mean strain fractions over that subject's samples.
    subject_order
        Panel order; subjects that do not occur in `_meta` are skipped.
    ncols, ax_width, ax_height
        Layout options forwarded to `lib.plot.subplots_grid`.
    """
    # Drop low-abundance strains once (0.05 threshold) and reuse the result for
    # both the fraction table and the genotype-based strain ordering.
    _world_drop_low_abund = _world.drop_low_abundance_strains(0.05)
    _frac = _world_drop_low_abund.community.to_pandas()

    # Palette ordered by genotype similarity, so similar strains get similar
    # colours. (Alternatives considered: the arbitrary column order of `_frac`,
    # or `maximally_shuffled_order` applied to a linkage of per-subject mean
    # fractions so that strains found in the same subject get distinct colours.)
    _strain_list = list(
        linkage_order(
            _world_drop_low_abund.genotype.linkage(optimal_ordering=True),
            _world_drop_low_abund.strain.values,
        )
    )
    _strain_list.remove(-1)  # Drop "other" strain.
    strain_palette = lib.plot.construct_ordered_palette(
        _strain_list,
        cm="rainbow",
    )
    # The "other" strain is always plotted first (bottom of each stack).
    _plot_columns = [-1] + _strain_list

    d0 = (
        _meta.assign(
            species_rabund=_species_rabund,
            species_rabundB=_species_rabundB,
        )
        .join(_frac)
        .sort_values(
            [
                "subject_id",
                "collection_date_relative_een_end",
                "sample_type",
                "source_samples",
                "diet_or_media",
                "mouse_genotype",
                "status_mouse_inflamed",
            ]
        )
    )
    # Shared x-range across panels, sized for the subject with the most samples.
    xlim = d0.subject_id.value_counts().max()

    _here_subject_list = _meta.subject_id.unique()
    _subject_order = [s for s in subject_order if s in _here_subject_list]
    fig, axs = lib.plot.subplots_grid(
        ncols=ncols,
        naxes=len(_subject_order),
        ax_width=ax_width,
        ax_height=ax_height,
    )
    for subject_id, ax in zip(_subject_order, axs.flatten()):
        ax.set_title(subject_id)
        d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
        if d1.empty:
            continue
        # Create the secondary axis only for panels that are actually drawn.
        twin_ax = ax.twinx()
        d1[_plot_columns].plot(
            kind="bar",
            width=0.95,
            stacked=True,
            color=strain_palette,
            ax=ax,
            edgecolor="k",
            lw=0.5,
        )
        d1.species_rabund.plot(
            kind="line", ax=twin_ax, color="k", marker=".", linestyle="-"
        )
        if _species_rabundB is not None:
            d1.species_rabundB.plot(
                kind="line", ax=twin_ax, color="midnightblue", marker=".", linestyle=":"
            )
        ax.legend_.set_visible(False)
        ax.set_ylim(0, 1)
        ax.set_ylabel("strain fraction")
        ax.set_xlabel("")
        twin_ax.set_ylabel("species relative abundance")
        twin_ax.set_ylim(0)
        lib.plot.rotate_xticklabels(ax)
        ax.set_xlim(-0.5, xlim + 0.5)
    fig.tight_layout()

    if savefig:
        fig.savefig(savefig)

    # Plot colorbars for each subject showing strain abundances.
    if plt_mean_key:
        fig, ax = plt.subplots()
        d2 = d0.groupby("subject_id")[_frac.columns].mean().reindex(_subject_order)
        d2[_plot_columns].plot(
            kind="bar",
            width=0.95,
            stacked=True,
            color=strain_palette,
            ax=ax,
            edgecolor="k",
            lw=0.5,
        )
        ax.legend_.set_visible(False)

In [None]:
def compete_two_fits(_world0, _world1, plot_npos=1000, low_abund_thresh=0.05):
    """Draw diagnostic plots comparing two worlds (fits) against each other.

    A random subset of positions is drawn from ``_world0`` and the same
    positions and samples are selected from ``_world1``. Low-abundance strains
    are dropped from both subsets using the same ``low_abund_thresh``. The
    metagenotype of the first subset and the communities of both subsets are
    plotted, always using the column linkage computed from the first subset's
    metagenotype.

    A 2x2 grid of axes is then filled as follows:

    - ``[0, 0]``: community entropy of the first subset vs. the second.
    - ``[0, 1]``: for each full world, ``unifrac_pdist()`` values vs.
      ``metagenotype.pdist()`` values.
    - ``[1, 0]``: per-sample Spearman correlation between each world's
      ``unifrac_pdist()`` row and the first world's ``metagenotype.pdist()``
      row, first world vs. second.
    - ``[1, 1]``: left empty.

    Parameters
    ----------
    _world0, _world1
        The two worlds to compare. ``_world1`` is indexed with the positions
        and samples of the subset drawn from ``_world0``.
    plot_npos : int
        Upper bound on the number of positions sampled for the plots.
    low_abund_thresh : float
        Threshold passed to ``drop_low_abundance_strains`` for *both* worlds.

    Returns
    -------
    None
        The function only produces figures.
    """
    w0 = _world0.random_sample(
        position=min(_world0.sizes["position"], plot_npos)
    ).drop_low_abundance_strains(low_abund_thresh)
    # Use the same threshold as for w0 so the two subsets are filtered
    # identically (previously a literal 0.05 ignored ``low_abund_thresh``).
    w1 = _world1.sel(position=w0.position, sample=w0.sample).drop_low_abundance_strains(
        low_abund_thresh
    )

    # The column linkage callbacks deliberately ignore their argument and
    # always use w0's metagenotype, so all three plots share one column order.
    sf.plot.plot_metagenotype(
        w0, col_linkage_func=lambda w: w0.metagenotype.linkage(optimal_ordering=True)
    )
    sf.plot.plot_community(
        w0,
        col_linkage_func=lambda w: w0.metagenotype.linkage(optimal_ordering=True),
        row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
    )
    sf.plot.plot_community(
        w1,
        col_linkage_func=lambda w: w0.metagenotype.linkage(optimal_ordering=True),
        row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
    )

    # Pairwise distances on the full (un-subsampled) worlds, computed once and
    # reused by both the scatter plot and the per-sample correlations.
    _mgtp_pdist = _world0.metagenotype.pdist()
    _mgtp_pdist1 = _world1.metagenotype.pdist()
    _unifrac_pdist0 = _world0.unifrac_pdist()
    _unifrac_pdist1 = _world1.unifrac_pdist()

    fig, axs = plt.subplots(2, 2)
    ax = axs[0, 0]
    ax.scatter(w0.community.entropy(), w1.community.entropy())
    ax.plot([0, 2.5], [0, 2.5])  # identity line

    ax = axs[0, 1]
    ax.scatter(squareform(_unifrac_pdist0), squareform(_mgtp_pdist))
    ax.scatter(squareform(_unifrac_pdist1), squareform(_mgtp_pdist1))

    # Per-sample agreement between strain-level and metagenotype distances.
    # Both worlds are correlated against _world0's metagenotype distances.
    _world0_sample_corr = {}
    _world1_sample_corr = {}
    for sample in _world0.sample.values:
        _world0_sample_corr[sample] = sp.stats.spearmanr(
            _unifrac_pdist0.loc[sample], _mgtp_pdist.loc[sample]
        )[0]
        _world1_sample_corr[sample] = sp.stats.spearmanr(
            _unifrac_pdist1.loc[sample], _mgtp_pdist.loc[sample]
        )[0]

    sample_accuracy = pd.DataFrame(
        dict(world0=_world0_sample_corr, world1=_world1_sample_corr)
    )

    ax = axs[1, 0]
    ax.scatter("world0", "world1", data=sample_accuracy)
    ax.plot([0, 1], [0, 1])  # identity line

### Zotu4

#### s__Escherichia coli_D (Zotu4 -> 102506) (1 of 3)

In [None]:
# Show the motu_taxonomy rows whose g__ column contains "Escherichia".
motu_taxonomy.loc[lambda t: t["g__"].str.contains("Escherichia")]

In [None]:
# Show (index, value) pairs of rotu_taxonomy entries containing "Escherichia".
rotu_taxonomy.loc[lambda t: t.str.contains("Escherichia")].reset_index().values

In [None]:
# Strain tracking for s__Escherichia coli_D
# (species 102506 in motu_rabund; Zotu4 in rotu_rabund).
_species_id = "102506"
_rotu_list = ["Zotu4"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 8)),
    (["mouse"], (30, 9)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

In [None]:
# Subject H only: one single-column figure per sample type, reusing the
# _species_id / _rotu_list / _world currently defined in the notebook.
# NOTE(review): ax_height is unpacked but never passed on, and `i` is only
# referenced by the disabled savefig argument below -- confirm intent.
for i, (sample_type_list, (ax_width, ax_height)) in enumerate(
    [
        (["human"], (8, 4)),
        (["Fermenter", "Fermenter_inoculum"], (15, 8)),
        (["mouse"], (30, 9)),
    ]
):
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["H"]) & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=1,
        ax_width=ax_width,
        # savefig=f'fig/{_species_id}.strain_tracking.{i}.pdf',
    )

In [None]:
# Subjects A, B and H: one single-column figure per sample type, each saved to
# fig/<species id>.strain_tracking.<i>.pdf. Reuses the _species_id /
# _rotu_list / _world currently defined in the notebook.
# NOTE(review): ax_height is unpacked but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for i, (sample_type_list, (ax_width, ax_height)) in enumerate(
    [
        (["human"], (8, 4)),
        (["Fermenter", "Fermenter_inoculum"], (15, 8)),
        (["mouse"], (30, 9)),
    ]
):
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=1,
        ax_width=ax_width,
        savefig=f"fig/{_species_id}.strain_tracking.{i}.pdf",
    )

#### s__Klebsiella pneumoniae (Zotu4 -> 102538) (2 of 3)

In [None]:
# Strain tracking for s__Klebsiella pneumoniae
# (species 102538 in motu_rabund; Zotu4 in rotu_rabund).
_species_id = "102538"
_rotu_list = ["Zotu4"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

#### s__Escherichia marmotae (Zotu4 -> 102322) (3 of 3)

In [None]:
# Strain tracking for s__Escherichia marmotae
# (species 102322 in motu_rabund; Zotu4 in rotu_rabund).
_species_id = "102322"
_rotu_list = ["Zotu4"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Enterococcus_D sp002850555 (Zotu85 -> 100323)

In [None]:
# Strain tracking for s__Enterococcus_D sp002850555
# (species 100323 in motu_rabund; Zotu85 in rotu_rabund).
_species_id = "100323"
_rotu_list = ["Zotu85"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Eggerthella lenta (Zotu172 -> 102544)

In [None]:
# Strain tracking for s__Eggerthella lenta
# (species 102544 in motu_rabund; Zotu172 in rotu_rabund).
_species_id = "102544"
_rotu_list = ["Zotu172"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Dorea scindens (Zotu114 -> 101303)

In [None]:
# Strain tracking for s__Dorea scindens
# (species 101303 in motu_rabund; Zotu114 in rotu_rabund).
_species_id = "101303"
_rotu_list = ["Zotu114"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Clostridium_Q symbiosum (Zotu25 -> 101367)

In [None]:
# Strain tracking for s__Clostridium_Q symbiosum
# (species 101367 in motu_rabund; Zotu25 in rotu_rabund).
_species_id = "101367"
_rotu_list = ["Zotu25"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### Zotu5

#### s__Clostridium_M clostridioforme (Zotu5 -> 101386) (1 of 2)

In [None]:
# Strain tracking for s__Clostridium_M clostridioforme
# (species 101386 in motu_rabund; Zotu5 in rotu_rabund).
_species_id = "101386"
_rotu_list = ["Zotu5"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

#### s__Clostridium_M bolteae (Zotu5 -> 101493) (2 of 2)

In [None]:
# Strain tracking for s__Clostridium_M bolteae
# (species 101493 in motu_rabund; Zotu5 in rotu_rabund).
_species_id = "101493"
_rotu_list = ["Zotu5"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Hungatella effluvii (Zotu10 -> 100032)

In [None]:
# Strain tracking for s__Hungatella effluvii
# (species 100032 in motu_rabund; Zotu10 in rotu_rabund).
_species_id = "100032"
_rotu_list = ["Zotu10"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Hungatella hathewayi (Zotu160 -> 100150)

In [None]:
# Strain tracking for s__Hungatella hathewayi
# (species 100150 in motu_rabund; Zotu160 in rotu_rabund).
_species_id = "100150"
_rotu_list = ["Zotu160"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Clostridium_M sp-? (Zotu31 -> 100179)

In [None]:
# Strain tracking for s__Clostridium_M sp-?
# (species 100179 in motu_rabund; Zotu31 in rotu_rabund).
_species_id = "100179"
_rotu_list = ["Zotu31"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Clostridium_M sp000431375 (Zotu155 -> 100242)

In [None]:
# Strain tracking for s__Clostridium_M sp000431375
# (species 100242 in motu_rabund; Zotu155 in rotu_rabund).
_species_id = "100242"
_rotu_list = ["Zotu155"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Erysipelatoclostridium ramosum (Zotu46 -> 101400)

In [None]:
# Strain tracking for s__Erysipelatoclostridium ramosum
# (species 101400 in motu_rabund; Zotu46 in rotu_rabund).
_species_id = "101400"
_rotu_list = ["Zotu46"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Flavonifractor plautii (Zotu49 -> 100099)

In [None]:
# Strain tracking for s__Flavonifractor plautii
# (species 100099 in motu_rabund; Zotu49 in rotu_rabund).
_species_id = "100099"
# The section heading pairs this species with Zotu49. "Zotu46" belongs to the
# Erysipelatoclostridium ramosum section and would overlay the wrong rOTU
# abundance on these plots.
_rotu_list = ["Zotu49"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(1),
    _world=_world,
    _meta=sample[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(1),
        _world=_world,
        _meta=sample[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Eisenbergiella tayi (Zotu75 -> 102330)

In [None]:
# Strain tracking for s__Eisenbergiella tayi
# (species 102330 in motu_rabund; Zotu75 in rotu_rabund).
_species_id = "102330"
_rotu_list = ["Zotu75"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Ruthenibacterium lactatiformans (Zotu78 -> 103682)

In [None]:
# Strain tracking for s__Ruthenibacterium lactatiformans
# (species 103682 in motu_rabund; Zotu78 in rotu_rabund).
_species_id = "103682"
_rotu_list = ["Zotu78"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Blautia_A wexlerae (Zotu13 -> 101338)

In [None]:
# Strain tracking for s__Blautia_A wexlerae
# (species 101338 in motu_rabund; Zotu13 in rotu_rabund).
_species_id = "101338"
_rotu_list = ["Zotu13"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Bacteroides caccae (Zotu74 -> 102549)

In [None]:
# Strain tracking for s__Bacteroides caccae
# (species 102549 in motu_rabund; Zotu74 in rotu_rabund).
_species_id = "102549"
_rotu_list = ["Zotu74"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Bacteroides dorei (Zotu1 -> 102478)

In [None]:
# Strain tracking for s__Bacteroides dorei
# (species 102478 in motu_rabund; Zotu1 in rotu_rabund).
_species_id = "102478"
_rotu_list = ["Zotu1"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Bacteroides uniformis (Zotu6 -> 101346)

In [None]:
# Strain tracking for s__Bacteroides uniformis
# (species 101346 in motu_rabund; Zotu6 in rotu_rabund).
_species_id = "101346"
_rotu_list = ["Zotu6"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

### s__Bacteroides ovatus (Zotu1 -> 101378 -> Zotu14+Zotu23)

In [None]:
# Strain tracking for s__Bacteroides ovatus
# (species 101378 in motu_rabund; Zotu14 + Zotu23, summed, in rotu_rabund).
_species_id = "101378"
_rotu_list = ["Zotu14", "Zotu23"]

# Load the fitted world, rename each sample to "CF_<n>" (n = integer parsed
# from the second "_"-separated field), then exchange the CF_11 / CF_15 labels.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
).rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
).rename_coords(
    sample={"CF_11": "CF_15", "CF_15": "CF_11"}  # Sample swap
)

print(motu_taxonomy.loc[_species_id])

# Per-subject panels for the EEN / PostEEN samples, plus the mean-fraction key.
strains_in_subjects(
    _species_rabund=motu_rabund[_species_id],
    _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
    _world=_world,
    _meta=sample.loc[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    plt_mean_key=True,
    ncols=3,
)

# One figure per sample type for subjects A, B and H.
# NOTE(review): ax_height is unpacked from each tuple but never passed to
# strains_in_subjects -- confirm whether it should be forwarded.
for sample_type_list, (ax_width, ax_height) in [
    (["human"], (8, 4)),
    (["Fermenter", "Fermenter_inoculum"], (15, 5)),
    (["mouse"], (30, 7)),
]:
    strains_in_subjects(
        _species_rabund=motu_rabund[_species_id],
        _species_rabundB=rotu_rabund[_rotu_list].sum(axis=1),
        _world=_world,
        _meta=sample.loc[
            lambda x: x.subject_id.isin(["A", "B", "H"])
            & x.sample_type.isin(sample_type_list)
        ].assign(label=lambda x: x.apply(_label_experiment_sample, axis=1)),
        ncols=3,
        ax_width=ax_width,
    )

## Permutation Test Prototype #3

This permutation test differs from earlier versions in that:
- it permutes the pair labels,
- it fits a cubic spline (with 4 knots) to the time data, and
- it is run on the compositional data (instead of metagenomics).

In [None]:
def _pair_type_contrasts(params):
    """Append the three pairwise pair-type contrasts to a coefficient Series.

    `params` is modified in place (three new entries) and returned.
    """
    een = params["C(pair_type)[EEN]"]
    post_een = params["C(pair_type)[PostEEN]"]
    transition = params["C(pair_type)[Transition]"]
    params["PostEEN - EEN"] = post_een - een
    params["Transition - EEN"] = transition - een
    params["Transition - PostEEN"] = transition - post_een
    return params


def _reference_subject_effect(fit, subject_ids):
    """Return `(subject, effect)` for a subject with an explicit coefficient.

    The second unique subject is preferred (the historical choice). Under Sum
    coding one level has no coefficient of its own, so other subjects are
    tried when the preferred one is missing from `fit.params`. If no subject
    has a coefficient, the first candidate is returned with an effect of 0.
    """
    params = fit.params
    candidates = (
        list(subject_ids[1:2]) + list(subject_ids[:1]) + list(subject_ids[2:])
    )
    for subject in candidates:
        name = f"C(subject_id, Sum)[S.{subject}]"
        if name in params.index:
            return subject, params[name]
    return candidates[0], 0.0


def _plot_mean_subject_curves(
    ax,
    fit,
    within,
    subject,
    subject_effect,
    pair_type_order,
    pair_type_linestyle_palette,
    log_start,
):
    """Draw the fitted dissimilarity-vs-time curve for each pair type on `ax`.

    Predictions are made for `subject` on a log-spaced time grid and the
    subject's own coefficient is subtracted. Each curve is clipped to the
    5%-95% quantile range of the observed `time_delta` for that pair type.
    """
    grid = pd.DataFrame(
        product(
            [subject],
            ["EEN", "PostEEN", "Transition"],
            np.logspace(log_start, 2.6),
            [1.0],
        ),
        columns=["subject_id", "pair_type", "time_delta", "depth_min"],
    )
    grid = grid.assign(
        prediction=fit.predict(grid),
        predict_mean_subject=lambda x: x.prediction - subject_effect,
    )
    for pair_type in pair_type_order:
        curve = grid[lambda x: x.pair_type == pair_type]
        left, right = within[lambda x: x.pair_type == pair_type].time_delta.quantile(
            [0.05, 0.95]
        )
        ax.plot(
            "time_delta",
            "predict_mean_subject",
            label="__nolegend__",
            data=curve[lambda x: (x.time_delta > left) & (x.time_delta < right)],
            color="black",
            linestyle=pair_type_linestyle_palette[pair_type],
        )


def turnover_analysis3(
    _rabund,
    _meta,
    _rabund_ctrl=None,  # e.g. control for species-level turnover
    pair_type_palette=pair_type_palette,
    pair_type_marker_palette=pair_type_marker_palette,
    pair_type_linestyle_palette=pair_type_linestyle_palette,
    subject_palette=subject_palette,
    subject_order=subject_order,
    pair_type_order=pair_type_order,
    _dmat=None,
    n_perm=999,
):
    """Model within-subject pairwise dissimilarity and permutation-test pair types.

    Every pair of samples is classified by how many of its two samples are
    "PostEEN" (0: "EEN", 1: "Transition", 2: "PostEEN"). Within-subject pairs
    are fit with OLS using
    ``bc ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)``,
    and the differences between pair-type coefficients are compared against
    `n_perm` fits with permuted pair-type labels.

    Parameters
    ----------
    _rabund : DataFrame
        Samples-by-features table, aligned against `_meta` via
        `lib.pandas_util.align_indexes`. Used for Bray-Curtis dissimilarity
        when `_dmat` is None.
    _meta : DataFrame
        Sample metadata. The columns `collection_date_relative_een_end`,
        `depth`, `subject_id` and `diet_or_media` are read.
    _rabund_ctrl : DataFrame, optional
        When given (and `_dmat` is None), its Bray-Curtis dissimilarities are
        subtracted from those of `_rabund`.
    pair_type_palette, pair_type_marker_palette, pair_type_linestyle_palette
        Mappings from pair type to color / marker / linestyle.
    subject_palette
        Mapping from subject id to color.
    subject_order, pair_type_order
        Iteration and plotting order of subjects and pair types.
    _dmat : square DataFrame or array, optional
        Precomputed dissimilarity matrix. The caller must order it like the
        aligned samples. Mutually exclusive with `_rabund_ctrl`.
    n_perm : int
        Number of label permutations.

    Returns
    -------
    The `summary()` of the observed OLS fit.

    Side effects
    ------------
    Seeds NumPy's global RNG with 1, prints the observed contrasts and creates
    several matplotlib figures (pair-type distributions, permutation
    histograms, residual diagnostics, two turnover scatter plots and two
    legend-only figures).
    """
    # Select data
    _rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

    # Calculate pairwise comparisons (condensed vectors from here on).
    if _dmat is None:
        dissimilarity_label = "Bray-Curtis Dissimilarity"
        _dmat = pdist(_rabund, metric="braycurtis")
        if _rabund_ctrl is not None:
            # NOTE: bc stands for bray-curtis and is used as a stand-in for
            # dissimilarity throughout because
            # the default is to use bc dissimilarity.
            bc_cdist_ctrl = pdist(
                _rabund_ctrl.loc[_rabund.index], metric="braycurtis"
            )
            _dmat = _dmat - bc_cdist_ctrl
    else:
        # A caller-supplied matrix need not be Bray-Curtis, so label generically.
        dissimilarity_label = "Pairwise Dissimilarity"
        assert _rabund_ctrl is None
        assert _dmat.shape[0] == _rabund.shape[0]
        assert _dmat.shape[1] == _rabund.shape[0]
        _dmat = squareform(_dmat)
        # But it's the users responsibility to sort the dmat correctly.

    time_cdist = pdist(
        _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
    )
    depth_min_cdist = pdist(_meta[["depth"]], metric=lambda x, y: min(x, y))
    diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
    same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
    # Number of "PostEEN" samples in each pair: 0, 1 or 2.
    type_transition_indicator = pdist(
        _meta[["diet_or_media"]],
        metric=lambda x, y: (x == "PostEEN").astype(float)
        + (y == "PostEEN").astype(float),
    )

    # Build (sampleA, sampleB) labels for every pair. The upper triangle of
    # the square label table is taken to line the pairs up with the condensed
    # vectors above; row labels become sampleB and column labels sampleA.
    # NOTE(review): this relies on unstack() preserving the order of
    # _meta.index - confirm for unsorted indexes.
    pairs = (
        pd.DataFrame(
            squareform(diff_subject_cdist),
            index=_meta.index,
            columns=_meta.index,
        )
        .rename_axis(index="sampleA", columns="sampleB")
        .unstack()
        .index.to_frame()
        .assign(pair=lambda x: x[["sampleA", "sampleB"]].apply(tuple, axis=1))
        .pair.unstack()
    )
    pairs = pd.Series(pairs.values[np.triu_indices_from(pairs.values, k=1)])
    d0 = pd.DataFrame(
        dict(
            sampleA=pairs.str[0],
            sampleB=pairs.str[1],
            bc=_dmat,  # Really this should be generic for any dissimilarity, not only BC.
            same_subject=same_subject_cdist,
            type_transition_indicator=type_transition_indicator,
            diff_type=type_transition_indicator == 1,
            time_delta=time_cdist,
            depth_min=depth_min_cdist,
        )
    ).assign(
        pair_type=lambda x: x.type_transition_indicator.map(
            {0: "EEN", 1: "Transition", 2: "PostEEN"}
        ),
        subject_id=lambda x: x.sampleA.map(_meta.subject_id),
        subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    )
    # Within-subject pairs only; these are what the model is fit on.
    d1 = d0[lambda x: x.same_subject]

    # Plot the observed relationship
    fig, axs = plt.subplots(1, 2)
    ax = axs[0]
    sns.boxenplot(
        y="bc",
        x="same_subject",
        hue="pair_type",
        data=d0,
        ax=ax,
        palette=pair_type_palette,
        order=[False, True],
        hue_order=pair_type_order,
    )
    ax = axs[1]
    sns.stripplot(
        y="bc",
        x="pair_type",
        hue="subject_id",
        data=d1,
        ax=ax,
        palette=subject_palette,
        order=pair_type_order,
    )
    ax.legend_.set_visible(False)

    # Fit observed relationship
    # formula = "bc ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum) + cr(depth_min, 4)"
    formula = "bc ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)"
    fit = smf.ols(formula, data=d1).fit()
    observed_stat = _pair_type_contrasts(fit.params)
    contrast_names = ["Transition - EEN", "Transition - PostEEN", "PostEEN - EEN"]
    print(observed_stat[contrast_names])

    # Calculate permutations
    np.random.seed(1)
    perm_stat = {}
    for i in tqdm(range(n_perm)):
        # Permute sample-pair labels within subjects.
        # NOTE(review): labels are shuffled across ALL pairs in d0 that share
        # sampleA's subject (between-subject pairs included) and only then
        # filtered to within-subject pairs. Confirm this is the intended null;
        # shuffling within d1 would keep per-subject label counts fixed.
        _perm_type = d0.groupby("subject_id").pair_type.transform(
            np.random.permutation
        )
        perm_fit = smf.ols(
            formula,
            data=d0.assign(pair_type=_perm_type)[lambda x: x.same_subject],
        ).fit()
        perm_stat[i] = _pair_type_contrasts(perm_fit.params)
    perm_stat = pd.DataFrame(perm_stat).T

    perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

    # Plot permutation tests: one null histogram per contrast, with the
    # observed value as a vertical line and the p-value in the corner.
    fig, axs = plt.subplots(3)
    for ax, param_name in zip(axs, contrast_names):
        ax.set_title(param_name)
        ax.hist(perm_stat[param_name], bins=20)
        ax.axvline(observed_stat[param_name])
        ax.annotate(
            perm_pvalues[param_name],
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
        )
    fig.tight_layout()

    # Residual diagnostics; the influence table is computed once and reused.
    cooks_d = fit.get_influence().summary_frame().cooks_d
    d2 = d1.assign(
        resid_pearson=fit.resid_pearson,
        influence=cooks_d,
    ).sort_values("time_delta")

    fig, ax = plt.subplots()
    art = ax.scatter(
        "time_delta",
        "resid_pearson",
        c="influence",
        data=d2,
    )
    fig.colorbar(art, label="Cook's D")
    ax.set_xlabel("Within-Subjects Days between Samples")
    ax.set_xscale("symlog")

    # Subject whose coefficient is subtracted to draw subject-averaged curves.
    _arbitrary_subject, subject_effect = _reference_subject_effect(
        fit, d1.subject_id.unique()
    )

    # Turnover figure 1: points colored by pair type.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for pair_type in pair_type_order:
        d3 = d1[lambda x: (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=pair_type_palette[pair_type],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
            edgecolor="grey",
            lw=0.5,
        )
    _plot_mean_subject_curves(
        ax,
        fit,
        d1,
        _arbitrary_subject,
        subject_effect,
        pair_type_order,
        pair_type_linestyle_palette,
        log_start=1.0,
    )
    ax.set_ylabel("Pairwise Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)

    # Empty artists that only provide legend entries.
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color=pair_type_palette[pair_type],
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()

    # Turnover figure 2: points colored by subject, marker by pair type.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for subject_id, pair_type in product(subject_order, pair_type_order):
        d3 = d2[lambda x: (x.subject_id == subject_id) & (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=subject_palette[subject_id],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
            alpha=0.7,
        )
    _plot_mean_subject_curves(
        ax,
        fit,
        d1,
        _arbitrary_subject,
        subject_effect,
        pair_type_order,
        pair_type_linestyle_palette,
        log_start=1.1,
    )
    ax.set_ylabel(dissimilarity_label)
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)

    # Legend-only figures
    fig, ax = plt.subplots(figsize=(5, 2))
    for subject_id in subject_order:
        ax.scatter(
            [], [], label=subject_id, color=subject_palette[subject_id], marker="s"
        )
    ax.legend(ncols=5, title="Subject")

    fig, ax = plt.subplots(figsize=(2.5, 1.5))
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color="black",
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()
    return fit.summary()

### Metagenomic Species Turnover

In [None]:
# Species-level (motu) turnover on EEN / PostEEN samples that have a
# collection date relative to the end of EEN; depth is set to 0.
turnover_analysis3(
    _rabund=motu_rabund,
    _meta=sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)

### zOTU Species Turnover

In [None]:
# zOTU-level turnover on EEN / PostEEN samples that have a collection date
# relative to the end of EEN; depth is set to 0.
turnover_analysis3(
    _rabund=rotu_rabund,
    _meta=sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)

### zOTU Species Turnover using Generalized Unifrac

In [None]:
# Precomputed distance matrix read from Excel; the "ID" column becomes the index.
dmat_gunifrac = pd.read_excel(
    "raw/een-mgen/2023-09-28_deborah.haecker@tum-create.edu.sg/distance-matrix-gunif_Byron.xlsx",
    index_col="ID",
)

In [None]:
# Align zOTU abundances with the EEN / PostEEN metadata (dated samples only),
# then subset the precomputed matrix to the aligned samples, in the same order.
_rabund, _meta = lib.pandas_util.align_indexes(
    rotu_rabund,
    sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)
_dmat = dmat_gunifrac.loc[_meta.index, _meta.index]

turnover_analysis3(_rabund=_rabund, _meta=_meta, _dmat=_dmat, n_perm=999)

### zOTU Species Turnover (but same samples as Metagenomic Species)

In [None]:
# zOTU-level turnover, restricted to the samples present in motu_rabund.
turnover_analysis3(
    _rabund=rotu_rabund.loc[motu_rabund.index],
    _meta=sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)

### zOTU Species Turnover using Generalized Unifrac (but same samples as Metagenomic Species)

In [None]:
# Same as the UniFrac analysis above, but starting from the zOTU rows whose
# samples are present in motu_rabund.
_rabund, _meta = lib.pandas_util.align_indexes(
    rotu_rabund.loc[motu_rabund.index],
    sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)
_dmat = dmat_gunifrac.loc[_meta.index, _meta.index]

turnover_analysis3(_rabund=_rabund, _meta=_meta, _dmat=_dmat)

### Strain Turnover

In [None]:
# Strain-level (sotu) turnover on EEN / PostEEN samples that have a
# collection date relative to the end of EEN; depth is set to 0.
turnover_analysis3(
    _rabund=sotu_rabund,
    _meta=sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)

### Strain Turnover while Controlling for Species

In [None]:
# Strain-level turnover with the species-level (motu) table passed as the
# control abundance table.
turnover_analysis3(
    _rabund=sotu_rabund,
    _rabund_ctrl=motu_rabund,
    _meta=sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)

### Individual strain turnover

#### g__Escherichia coli_D

In [None]:
_species_id = "102506"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
)
_world = _world.rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
_world = _world.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap

# Strain composition for samples whose metagenotype mean depth exceeds 0.1.
_rabund = _world.community.sel(
    sample=_world.metagenotype.mean_depth() > 0.1
).to_pandas()
# EEN / PostEEN samples with a known collection date; depth is set to 0.
_meta = sample.loc[
    lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
    & df.collection_date_relative_een_end.notna()
].assign(depth=0)

turnover_analysis3(_rabund=_rabund, _meta=_meta)

In [None]:
_species_id = "102506"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
)
_world = _world.rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
_world = _world.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap

# Align the strain composition (samples with metagenotype mean depth > 0.1)
# with the EEN / PostEEN metadata (dated samples only, depth set to 0).
_rabund, _meta = lib.pandas_util.align_indexes(
    _world.community.sel(sample=_world.metagenotype.mean_depth() > 0.1).to_pandas(),
    sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)

# Non-discretized UniFrac distances, subset to the aligned samples in order.
_dmat = _world.unifrac_pdist(discretized=False).loc[_meta.index, _meta.index]

turnover_analysis3(_rabund=_rabund, _meta=_meta, _dmat=_dmat)

#### s__Eggerthella lenta

In [None]:
_species_id = "102544"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
)
_world = _world.rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
_world = _world.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap

# Strain composition for samples whose metagenotype mean depth exceeds 0.1.
_rabund = _world.community.sel(
    sample=_world.metagenotype.mean_depth() > 0.1
).to_pandas()
# EEN / PostEEN samples with a known collection date; depth is set to 0.
_meta = sample.loc[
    lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
    & df.collection_date_relative_een_end.notna()
].assign(depth=0)

turnover_analysis3(_rabund=_rabund, _meta=_meta)

In [None]:
_species_id = "102544"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
)
_world = _world.rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
_world = _world.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap

# Align the strain composition (samples with metagenotype mean depth > 0.1)
# with the EEN / PostEEN metadata (dated samples only, depth set to 0).
_rabund, _meta = lib.pandas_util.align_indexes(
    _world.community.sel(sample=_world.metagenotype.mean_depth() > 0.1).to_pandas(),
    sample.loc[
        lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
        & df.collection_date_relative_een_end.notna()
    ].assign(depth=0),
)

# Non-discretized UniFrac distances, subset to the aligned samples in order.
_dmat = _world.unifrac_pdist(discretized=False).loc[_meta.index, _meta.index]

turnover_analysis3(_rabund=_rabund, _meta=_meta, _dmat=_dmat)

#### s__Flavonifractor plautii

In [None]:
_species_id = "100099"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
)
_world = _world.rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
_world = _world.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap

# Strain composition for samples whose metagenotype mean depth exceeds 0.1.
_rabund = _world.community.sel(
    sample=_world.metagenotype.mean_depth() > 0.1
).to_pandas()
# EEN / PostEEN samples with a known collection date; depth is set to 0.
_meta = sample.loc[
    lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
    & df.collection_date_relative_een_end.notna()
].assign(depth=0)

turnover_analysis3(_rabund=_rabund, _meta=_meta)

In [None]:
# UniFrac-based strain turnover for the species of this section.
# The id was previously "102544" (the species of the preceding section),
# which silently re-ran that analysis here; it now matches the "100099"
# cell directly above.
_species_id = "100099"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap
)
# Strain composition for samples whose metagenotype mean depth exceeds 0.1.
_rabund = _world.community.sel(
    sample=_world.metagenotype.mean_depth() > 1e-1
).to_pandas()
# EEN / PostEEN samples with a known collection date; depth is set to 0.
_meta = sample[
    sample.diet_or_media.isin(["EEN", "PostEEN"])
    & ~sample.collection_date_relative_een_end.isna()
].assign(depth=0)

_rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

# Non-discretized UniFrac distances, subset to the aligned samples in order.
_dmat = _world.unifrac_pdist(discretized=False).loc[_meta.index, _meta.index]

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
    _dmat=_dmat,
)

#### s__Bacteroides uniformis

In [None]:
_species_id = "101346"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
)
_world = _world.rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
_world = _world.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap

# Strain composition for samples whose metagenotype mean depth exceeds 0.1.
_rabund = _world.community.sel(
    sample=_world.metagenotype.mean_depth() > 0.1
).to_pandas()
# EEN / PostEEN samples with a known collection date; depth is set to 0.
_meta = sample.loc[
    lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
    & df.collection_date_relative_een_end.notna()
].assign(depth=0)

turnover_analysis3(_rabund=_rabund, _meta=_meta)

#### s__Bacteroides ovatus

In [None]:
_species_id = "101378"

# Load the fitted world, rename samples to "CF_<int>" (integer taken from the
# second underscore-delimited field), then swap CF_11 and CF_15.
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
    validate=False,
)
_world = _world.rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
_world = _world.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap

# Strain composition for samples whose metagenotype mean depth exceeds 0.1.
_rabund = _world.community.sel(
    sample=_world.metagenotype.mean_depth() > 0.1
).to_pandas()
# EEN / PostEEN samples with a known collection date; depth is set to 0.
_meta = sample.loc[
    lambda df: df.diet_or_media.isin(["EEN", "PostEEN"])
    & df.collection_date_relative_een_end.notna()
].assign(depth=0)

turnover_analysis3(_rabund=_rabund, _meta=_meta)