## Preamble

In [None]:
%load_ext autoreload

In [None]:
# Run the notebook from the project root so the relative paths used below resolve.
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])  # KeyError if PROJECT_ROOT is not set
_os.path.realpath(_os.path.curdir)  # cell output: the resolved working directory

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
# Global plot styling: seaborn "paper" context and a figure DPI of 200.
sns.set_context("paper")
plt.rcParams["figure.dpi"] = 200

In [None]:
def _calculate_2tailed_pvalue_from_perm(obs, perms):
    """Return a two-tailed permutation p-value for an observed statistic.

    Each one-tailed p-value is estimated as ``(k + 1) / (n + 1)``, where ``k``
    is the number of permuted statistics strictly beyond ``obs`` in that
    direction and ``n`` is the number of permutations. The two-tailed value
    is twice the smaller tail, capped at 1.

    Parameters
    ----------
    obs
        Observed statistic: a scalar, or a pandas Series of statistics whose
        labels match the columns of ``perms``.
    perms
        Permuted statistics: a 1-d array, or a DataFrame with one row per
        permutation.

    Returns
    -------
    A scalar (array input) or a Series indexed like ``obs`` (pandas input),
    with every value in (0, 1].
    """
    n_total = len(perms) + 1
    p_upper = ((perms > obs).sum() + 1) / n_total
    p_lower = ((perms < obs).sum() + 1) / n_total
    # Doubling the smaller tail can exceed 1 when the observation sits near
    # the centre of the null distribution, so cap the result at 1.
    return np.minimum(2 * np.minimum(p_upper, p_lower), 1.0)

In [None]:
def linkage_order(linkage, labels):
    """Return ``labels`` reordered to follow the leaves of a linkage tree.

    ``linkage`` is a SciPy linkage matrix. ``labels`` must support indexing
    with a list of integer positions (e.g. a NumPy array or pandas Index).
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]


def is_prime(n):
    """Return True if ``n`` is prime, using trial division up to sqrt(n)."""
    if n <= 1:
        return False
    limit = int(n**0.5) + 1
    return all(n % divisor != 0 for divisor in range(2, limit))


def iterate_primes_up_to(n, return_index=False):
    """Yield the primes strictly below ``ceil(n)`` in increasing order.

    With ``return_index=True``, yield ``(position, prime)`` tuples instead,
    where ``position`` counts the primes from 0.
    """
    upper = int(np.ceil(n))
    primes = (candidate for candidate in range(upper) if is_prime(candidate))
    if return_index:
        yield from enumerate(primes)
    else:
        yield from primes


def maximally_shuffled_order(sorted_order):
    """Reorder ``sorted_order`` so that originally adjacent items are spread apart.

    Each item's original position is reduced modulo every prime below
    ``ceil(sqrt(n))``; sorting the positions by those residues gives a
    sequence of original positions, and item ``i`` is placed at the rank
    given by the ``i``-th entry of that sequence.

    Parameters
    ----------
    sorted_order
        Sequence of labels in their original order.

    Returns
    -------
    list
        The same labels in the shuffled order. Inputs with no usable prime
        (``n <= 4``, including empty input) are returned in their original
        order.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    if not primes_list:
        # No sort keys are available; keep the input order rather than
        # relying on how pandas handles an empty `by` list.
        return table.index.to_list()
    for prime in primes_list:
        table[prime] = table.original_order % prime
    table = table.assign(new_order=table.sort_values(primes_list).original_order.values)
    return table.sort_values("new_order").index.to_list()

## Construct Metadata

In [None]:
# Colors keyed by sample-pair type (the keys match pair_type_order below).
pair_type_palette={'Transition': 'plum', 'EEN': 'pink', 'PostEEN': 'lightblue'}

# Colors keyed by diet / growth medium.
diet_palette = {
    "EEN": "lightgreen",
    "PostEEN": "lightblue",
    "InVitro": "plum",
    "PreEEN": "lightpink",
}

# Fixed display order of subject IDs.
subject_order = [
    "A",
    "B",
    "H",
    "C",
    "D",
    "E",
    "F",
    "G",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
]

# NOTE: Requires a dummy value because I want exactly 20 items.
subject_palette = lib.plot.construct_ordered_palette(
    subject_order + [f"dummy{i}" for i in range(20 - len(subject_order))], cm="tab20"
)
subject_palette["X"] = "black"  # extra entry for a subject labelled "X"
# Display order, scatter markers and line styles for the three pair types.
pair_type_order = ["EEN", "Transition", "PostEEN"]
pair_type_marker_palette = {"EEN": "s", "Transition": ">", "PostEEN": "o"}
pair_type_linestyle_palette = {"EEN": ":", "Transition": "-.", "PostEEN": "-"}

In [None]:
# Per-sample metadata indexed by sample_id. `label` is a tuple of
# (collection_date_relative_een_end, diet_or_media, sample_id) for each row.
sample = (
    pd.read_table("meta/een-mgen/sample.tsv")
    .assign(
        label=lambda x: x[
            ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
        ].apply(tuple, axis=1)
    )
    .set_index("sample_id")
)
# Per-subject metadata indexed by subject_id.
subject = pd.read_table("meta/een-mgen/subject.tsv", index_col="subject_id")

In [None]:
# zOTU count table as read: rows are zOTUs, columns are samples plus a
# `taxonomy` column.
rotu_counts = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")
rotu_taxonomy = rotu_counts.taxonomy
# Drop the taxonomy column and transpose so that rows are samples.
rotu_counts = rotu_counts.drop(columns=["taxonomy"]).T
# Relative abundance: each sample (row) is divided by its total count.
rotu_rabund = rotu_counts.divide(rotu_counts.sum(1), axis=0)

# Average-linkage clustering of samples on Bray-Curtis dissimilarity.
sample_rotu_bc_linkage = sp.cluster.hierarchy.linkage(
    rotu_rabund, method="average", metric="braycurtis", optimal_ordering=True
)

In [None]:
# Samples present in the zOTU table but absent from the sample metadata.
missing_samples = sorted(idxwhere(~rotu_counts.index.to_series().isin(sample.index)))
print(len(missing_samples), ", ".join(missing_samples))

In [None]:
# Heatmap of zOTU relative abundance with samples ordered by the
# Bray-Curtis linkage computed above.
# NOTE(review): `x` is assigned here but the clustermap call uses
# `rotu_rabund` directly.
x = rotu_rabund
# Row annotations: subject color, and black for the two samples CF_11 / CF_15
# (treated as a sample swap elsewhere in this notebook), grey otherwise.
row_colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    )
)
row_linkage = sample_rotu_bc_linkage

sns.clustermap(
    rotu_rabund,
    norm=mpl.colors.PowerNorm(1 / 5),  # power-law color scale with exponent 1/5
    row_colors=row_colors,
    row_linkage=row_linkage,
)

In [None]:
# Display the raw GT-Pro species depth table.
pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv")

In [None]:
# Species depth in long format, pivoted to a samples x species table with
# missing combinations filled with 0. Sample names are rewritten as
# "CF_<int>" from their second "_"-separated field, and the labels CF_11 and
# CF_15 are exchanged.
gtpro_depth = (pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv",
    index_col=['sample', "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    .rename({'CF_15': 'CF_11', 'CF_11': 'CF_15'})  # Sample swap
)
# Relative abundance: each sample (row) is divided by its total depth.
gtpro_rabund = gtpro_depth.divide(gtpro_depth.sum(1), axis=0)

gtpro_rabund

In [None]:
# Species depth in long format (column names are supplied explicitly via
# `names`), pivoted to a samples x species table with missing combinations
# filled with 0. Sample names are rewritten as "CF_<int>" and the labels
# CF_11 and CF_15 are exchanged, as for the GT-Pro table.
motu_depth = (pd.read_table(
    "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
    names=['sample', "species_id", 'depth'], index_col=['sample', "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    .rename({'CF_15': 'CF_11', 'CF_11': 'CF_15'})  # Sample swap
)
# Relative abundance: each sample (row) is divided by its total depth.
motu_rabund = motu_depth.divide(motu_depth.sum(1), axis=0)

motu_rabund

In [None]:
# Compare sample clustering between the metagenomic species profile (x) and
# the zOTU profile (y).
x, y = align_indexes(motu_rabund, rotu_rabund)


x_linkage = linkage(x, method="average", metric="braycurtis", optimal_ordering=True)
y_linkage = linkage(y, method="average", metric="braycurtis", optimal_ordering=True)
# Annotations: subject color, and black for CF_11 / CF_15 (grey otherwise).
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    )
)

# Pairwise Bray-Curtis dissimilarity between samples, computed from x.
x_pdist = pd.DataFrame(
    squareform(pdist(x, metric="braycurtis")), index=x.index, columns=x.index
)
# Rows are ordered by the y-based linkage and columns by the x-based linkage,
# so the two clusterings can be compared on the same distance matrix.
sns.clustermap(
    x_pdist,
    row_linkage=y_linkage,
    col_linkage=x_linkage,
    row_colors=colors,
    col_colors=colors,
)

In [None]:
# Compare sample clustering between the metagenomic species profile (x) and
# the GT-Pro species profile (y).
x, y = align_indexes(motu_rabund, gtpro_rabund)


x_linkage = linkage(x, method="average", metric="braycurtis", optimal_ordering=True)
y_linkage = linkage(y, method="average", metric="braycurtis", optimal_ordering=True)
# Annotations: subject color, and black for CF_11 / CF_15 (grey otherwise).
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    )
)

# Pairwise Bray-Curtis dissimilarity between samples, computed from x.
x_pdist = pd.DataFrame(
    squareform(pdist(x, metric="braycurtis")), index=x.index, columns=x.index
)
# Rows are ordered by the y-based linkage and columns by the x-based linkage.
sns.clustermap(
    x_pdist,
    row_linkage=y_linkage,
    col_linkage=x_linkage,
    row_colors=colors,
    col_colors=colors,
)

In [None]:
# Histograms of total depth summed per sample and per species, on a log x-axis.
# NOTE(review): `bins` is not used in this cell; the histograms use the
# log-spaced bins passed to ax.hist below.
bins = np.linspace(0, 30_000, num=200)

fig, axs = plt.subplots(2, sharex=True)

for (title, x), ax in zip(
    dict(
        total_depth_by_sample=motu_depth.sum(1),
        total_depth_by_species=motu_depth.sum(0),
    ).items(),
    axs.flatten(),
):
    ax.hist(x, bins=np.logspace(-1, 5, num=100))
    ax.set_title(title)
    ax.set_xscale("log")
fig.tight_layout()

In [None]:
# The 20 species with the highest mean relative abundance across samples.
motu_rabund.mean().sort_values(ascending=False).head(20)

In [None]:
# Stacked histograms of relative abundance for the most prevalent species,
# where prevalence is the fraction of samples with relative abundance > 1e-5.
n_species = 10
top_motus = (
    (motu_rabund > 1e-5).sum().sort_values(ascending=False).head(n_species).index
)

fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_motus, axs):
    # ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(motu_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    prevalence = (motu_rabund[species_id] > 1e-5).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    # Hide axes, background and spines so that each panel shows only its bars.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    ax.axvline(1e-5, lw=1, linestyle=":", color="k")  # prevalence threshold

# `ax` is the last panel from the loop; only it gets a visible x-axis.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a Series keyed by rank prefix.

    The string must contain exactly seven fields, which are labelled with the
    rank prefixes ``d__`` through ``s__`` in order.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
motu_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Keep only rows where the genome is its own species representative, and
# build species_id as "1" followed by the third "-"-separated field of
# MGnify_accession.
_motu_taxonomy = (
    pd.read_table(motu_taxonomy_inpath)[lambda x: x.Genome == x.Species_rep]
    .assign(species_id=lambda x: "1" + x.MGnify_accession.str.split("-").str[2])
    .set_index("species_id")
)

# motu_lineage_string = _motu_taxonomy.Lineage

# One column per taxonomic rank, parsed from the Lineage string.
motu_taxonomy = _motu_taxonomy.Lineage.apply(
    parse_taxonomy_string
)  # .assign(taxonomy_string=motu_lineage_string)
motu_taxonomy

In [None]:
# Print the full lineage of each of the most prevalent species.
for _species_id in top_motus.astype(str):
    print(_species_id, ":", ";".join(motu_taxonomy.loc[_species_id].values))

In [None]:
# Build a strain-level depth table: for each species, split its depth across
# the strains of the fitted community table, with a residual "-1" column.
sotu_depth = []
missing_files = []
for species_id in motu_depth.columns:
    path = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        d = (
            pd.read_table(path, index_col=["sample", "strain"])
            .squeeze()
            .unstack()
            .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            .rename({'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap.
        )
    except FileNotFoundError:
        # No strain fit for this species: record the path and continue with an
        # empty table, so all of its depth ends up in the residual column.
        missing_files.append(path)
        d = pd.DataFrame([])
    # Keep strains whose fraction summed over samples exceeds 0.05.
    _keep_strains = idxwhere(d.sum() > 0.05)
    assert d.index.isin(motu_depth.index).all()
    d = d.reindex(index=motu_depth.index, columns=_keep_strains, fill_value=0)
    # Residual fraction not assigned to a kept strain, stored under column -1.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    d[d < 0] = 0
    # Renormalize rows to sum to 1, then scale by the species depth.
    d = d.divide(d.sum(1), axis=0)
    d = d.multiply(motu_depth[species_id], axis=0)
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    sotu_depth.append(d)
sotu_depth = pd.concat(sotu_depth, axis=1)
sotu_rabund = sotu_depth.divide(sotu_depth.sum(1), axis=0)
# Cell output: number of species and number of species without a strain fit.
len(motu_depth.columns), len(missing_files)

## Permutation Test Prototype #3

This permutation test differs from earlier versions in that:
- it permutes the pair labels,
- it fits a cubic spline (with 4 knots) to the time data, and
- it is run on the compositional data (instead of metagenomics).

In [None]:
def _add_pair_type_contrasts(params):
    """Append the three pairwise pair-type contrasts to a parameter Series."""
    een = params["C(pair_type)[EEN]"]
    post_een = params["C(pair_type)[PostEEN]"]
    transition = params["C(pair_type)[Transition]"]
    params["PostEEN - EEN"] = post_een - een
    params["Transition - EEN"] = transition - een
    params["Transition - PostEEN"] = transition - post_een
    return params


def turnover_analysis3(
    _rabund,
    _meta,
    _rabund_ctrl=None,  # e.g. control for species-level turnover
    pair_type_palette=pair_type_palette,
    pair_type_marker_palette=pair_type_marker_palette,
    pair_type_linestyle_palette=pair_type_linestyle_palette,
    subject_palette=subject_palette,
    subject_order=subject_order,
    pair_type_order=pair_type_order,
    _dmat=None,
    n_perm=999,
):
    """Model within-subject pairwise dissimilarity and test pair-type effects.

    Every pair of samples gets a dissimilarity and is classified as "EEN",
    "Transition" or "PostEEN" by how many of its two samples have
    ``diet_or_media == "PostEEN"`` (0, 1 or 2). An OLS model
    ``bc ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)`` is fit
    to the within-subject pairs. The differences between pair-type
    coefficients are compared with a permutation null distribution.

    Parameters
    ----------
    _rabund : DataFrame
        Samples x features abundance table. Used for Bray-Curtis
        dissimilarity when ``_dmat`` is None, and always for sample alignment.
    _meta : DataFrame
        Sample metadata with columns ``collection_date_relative_een_end``,
        ``depth``, ``subject_id`` and ``diet_or_media``.
    _rabund_ctrl : DataFrame, optional
        Control abundance table. Its Bray-Curtis dissimilarity is subtracted
        from that of ``_rabund``. Only valid when ``_dmat`` is None.
    pair_type_palette, pair_type_marker_palette, pair_type_linestyle_palette,
    subject_palette : dict
        Plot styling keyed by pair type / subject.
    subject_order, pair_type_order : list
        Plotting order of subjects / pair types.
    _dmat : square array-like, optional
        Precomputed sample x sample dissimilarity matrix. A DataFrame whose
        index and columns contain all aligned sample IDs is reordered to
        match. Anything else must already be in the aligned sample order.
    n_perm : int
        Number of permutations.

    Returns
    -------
    The statsmodels summary of the observed fit. Seven matplotlib figures are
    created as a side effect, and the global NumPy random seed is set to 1.
    """
    # Select data
    _rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

    # Calculate pairwise comparisons (condensed form, pdist ordering).
    if _dmat is None:
        _dmat = pdist(_rabund, metric="braycurtis")
        if _rabund_ctrl is not None:
            # NOTE: bc stands for bray-curtis and is used as a stand-in for
            # dissimilarity throughout because
            # the default is to use bc dissimilarity.
            bc_cdist_ctrl = pdist(_rabund_ctrl.loc[_rabund.index], metric="braycurtis")
            _dmat = _dmat - bc_cdist_ctrl
    else:
        assert _rabund_ctrl is None
        if (
            isinstance(_dmat, pd.DataFrame)
            and _meta.index.isin(_dmat.index).all()
            and _meta.index.isin(_dmat.columns).all()
        ):
            # A matrix labelled by sample ID can be put in the aligned order
            # here instead of relying on the caller to pre-sort it.
            _dmat = _dmat.loc[_meta.index, _meta.index]
        # Otherwise it's the user's responsibility to sort the dmat correctly.
        assert _dmat.shape[0] == _rabund.shape[0]
        assert _dmat.shape[1] == _rabund.shape[0]
        _dmat = squareform(_dmat)

    time_cdist = pdist(
        _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
    )
    depth_min_cdist = pdist(_meta[["depth"]], metric=lambda x, y: min(x, y))
    diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
    same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
    # Number of samples in the pair (0, 1 or 2) taken during PostEEN.
    type_transition_indicator = pdist(
        _meta[["diet_or_media"]],
        metric=lambda x, y: (x == "PostEEN").astype(float)
        + (y == "PostEEN").astype(float),
    )

    # Label each condensed-distance entry with its two samples. The pairs are
    # enumerated positionally in the same (i < j, row-major) order pdist and
    # squareform use, so the labels match the distances for any sample order.
    # (Building the labels through unstack() sorts the sample IDs and only
    # matches when the index is already sorted.)
    first_idx, second_idx = np.triu_indices(len(_meta), k=1)
    sample_ids = _meta.index.to_numpy()
    d0 = pd.DataFrame(
        dict(
            sampleA=sample_ids[second_idx],
            sampleB=sample_ids[first_idx],
            bc=_dmat,  # Really this should be generic for any dissimilarity, not only BC.
            same_subject=same_subject_cdist,
            type_transition_indicator=type_transition_indicator,
            diff_type=type_transition_indicator == 1,
            time_delta=time_cdist,
            depth_min=depth_min_cdist,
        )
    ).assign(
        pair_type=lambda x: x.type_transition_indicator.map(
            {0: "EEN", 1: "Transition", 2: "PostEEN"}
        ),
        subject_id=lambda x: x.sampleA.map(_meta.subject_id),
        subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    )
    d1 = d0[lambda x: x.same_subject]

    # Figure 1: observed dissimilarity by pair type.
    fig, axs = plt.subplots(1, 2)
    ax = axs[0]
    sns.boxenplot(
        y="bc",
        x="same_subject",
        hue="pair_type",
        data=d0,
        ax=ax,
        palette=pair_type_palette,
        order=[False, True],
        hue_order=pair_type_order,
    )
    ax = axs[1]
    sns.stripplot(
        y="bc",
        x="pair_type",
        hue="subject_id",
        data=d1,
        ax=ax,
        palette=subject_palette,
        order=pair_type_order,
    )
    ax.legend_.set_visible(False)

    # Fit observed relationship
    formula = "bc ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)"
    fit = smf.ols(formula, data=d1).fit()
    observed_stat = _add_pair_type_contrasts(fit.params)
    contrast_names = ["Transition - EEN", "Transition - PostEEN", "PostEEN - EEN"]
    print(observed_stat[contrast_names])

    # Calculate permutations
    np.random.seed(1)
    perm_stat = {}
    for i in tqdm(range(n_perm)):
        # NOTE(review): labels are shuffled across *all* pairs that share
        # sampleA's subject, including between-subject pairs, and only then
        # restricted to within-subject pairs. Confirm that this pooling is
        # the intended null.
        _perm_type = d0.groupby("subject_id").pair_type.transform(
            np.random.permutation
        )
        perm_fit = smf.ols(
            formula,
            data=d0.assign(pair_type=_perm_type)[lambda x: x.same_subject],
        ).fit()
        perm_stat[i] = _add_pair_type_contrasts(perm_fit.params)
    perm_stat = pd.DataFrame(perm_stat).T

    perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

    # Figure 2: permutation null distribution for each contrast, with the
    # observed value marked and the two-tailed p-value annotated.
    fig, axs = plt.subplots(3)
    for ax, param_name in zip(axs, contrast_names):
        ax.set_title(param_name)
        ax.hist(perm_stat[param_name], bins=20)
        ax.axvline(observed_stat[param_name])
        ax.annotate(
            perm_pvalues[param_name],
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
        )
    fig.tight_layout()

    # Predictions are shown for one arbitrary subject with that subject's
    # fitted effect subtracted.
    _arbitrary_subject = d1.subject_id.unique()[1]
    subject_effect = fit.params[f"C(subject_id, Sum)[S.{_arbitrary_subject}]"]
    d2 = d1.assign(
        predict=fit.predict(),
        predict_mean_subject=fit.predict(d1.assign(subject_id=_arbitrary_subject))
        - subject_effect,
        resid_pearson=fit.resid_pearson,
        influence=fit.get_influence().summary_frame().cooks_d,
    ).sort_values("time_delta")

    def _draw_fitted_curves(ax, log10_start):
        """Plot the fitted curve of each pair type over its central time range."""
        predict_data = pd.DataFrame(
            product(
                [_arbitrary_subject],
                ["EEN", "PostEEN", "Transition"],
                np.logspace(log10_start, 2.6),
                [1.0],
            ),
            columns=["subject_id", "pair_type", "time_delta", "depth_min"],
        )
        predict_data = predict_data.assign(
            prediction=fit.predict(predict_data),
            predict_mean_subject=lambda x: x.prediction - subject_effect,
        )
        for pair_type in pair_type_order:
            d4 = predict_data[lambda x: x.pair_type == pair_type]
            # Only draw between the 5% and 95% quantiles of the observed
            # time deltas for this pair type.
            left, right = d1[lambda x: x.pair_type == pair_type].time_delta.quantile(
                [0.05, 0.95]
            )
            ax.plot(
                "time_delta",
                "predict_mean_subject",
                label="__nolegend__",
                data=d4[lambda x: (x.time_delta > left) & (x.time_delta < right)],
                color="black",
                linestyle=pair_type_linestyle_palette[pair_type],
            )

    # Figure 3: residual diagnostics.
    fig, ax = plt.subplots()
    art = ax.scatter(
        "time_delta",
        "resid_pearson",
        c="influence",
        data=d2,
    )
    fig.colorbar(art, label="Cook's D")
    ax.set_xlabel("Within-Subjects Days between Samples")
    ax.set_xscale("symlog")

    # Figure 4: within-subject turnover colored by pair type.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for pair_type in pair_type_order:
        d3 = d1[lambda x: (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=pair_type_palette[pair_type],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
            edgecolor="grey",
            lw=0.5,
        )
    _draw_fitted_curves(ax, 1.0)
    ax.set_ylabel("Pairwise Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)
    # Empty proxy artists provide the legend entries.
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color=pair_type_palette[pair_type],
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()

    # Figure 5: within-subject turnover colored by subject.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for subject_id, pair_type in product(subject_order, pair_type_order):
        d3 = d2[lambda x: (x.subject_id == subject_id) & (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=subject_palette[subject_id],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
            alpha=0.7,
        )
    _draw_fitted_curves(ax, 1.1)
    ax.set_ylabel("Bray-Curtis Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)

    # Figure 6: standalone subject legend.
    fig, ax = plt.subplots(figsize=(5, 2))
    for subject_id in subject_order:
        ax.scatter(
            [], [], label=subject_id, color=subject_palette[subject_id], marker="s"
        )
    ax.legend(ncols=5, title="Subject")

    # Figure 7: standalone pair-type legend.
    fig, ax = plt.subplots(figsize=(2.5, 1.5))
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color="black",
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()
    return fit.summary()

### Metagenomic Species Turnover

In [None]:
# Turnover analysis on metagenomic species relative abundance, restricted to
# EEN / PostEEN samples with a known relative collection date. `depth` is a
# placeholder column of zeros.
turnover_analysis3(
    _rabund=motu_rabund,
    _meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0),
)

### zOTU Species Turnover

In [None]:
# Same analysis on zOTU relative abundance (Bray-Curtis dissimilarity).
turnover_analysis3(
    _rabund=rotu_rabund,
    _meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0),
)

### zOTU Species Turnover using Generalized Unifrac

In [None]:
# Precomputed generalized UniFrac distance matrix, indexed by the `ID` column.
dmat_gunifrac = pd.read_excel('raw/een-mgen/2023-09-28_deborah.haecker@tum-create.edu.sg/distance-matrix-gunif_Byron.xlsx', index_col='ID')

In [None]:
# zOTU turnover using the precomputed generalized UniFrac distances.
_rabund = rotu_rabund
_meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

# Align first so that the distance matrix can be subset in the same sample order.
_rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

_dmat = dmat_gunifrac.loc[_meta.index, _meta.index]

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
    _dmat=_dmat,
    n_perm=999,
)

### zOTU Species Turnover (but same samples as Metagenomic Species)

In [None]:
# zOTU turnover restricted to the samples in the metagenomic species table.
# `.loc` raises KeyError if any of those samples is absent from the zOTU table.
turnover_analysis3(
    _rabund=rotu_rabund.loc[motu_rabund.index],
    _meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0),
)

### zOTU Species Turnover using Generalized Unifrac (but same samples as Metagenomic Species)

In [None]:
# zOTU turnover with generalized UniFrac distances, restricted to the samples
# in the metagenomic species table.
_rabund = rotu_rabund.loc[motu_rabund.index]
_meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

# Align first so that the distance matrix can be subset in the same sample order.
_rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

_dmat = dmat_gunifrac.loc[_meta.index, _meta.index]

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
    _dmat=_dmat,
)

### Strain Turnover

In [None]:
# Turnover analysis on strain-level relative abundance.
turnover_analysis3(
    _rabund=sotu_rabund,
    _meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0),
)

### Strain Turnover while Controlling for Species

In [None]:
# Strain-level turnover with the species-level Bray-Curtis dissimilarity
# subtracted from the strain-level one (via `_rabund_ctrl`).
turnover_analysis3(
    _rabund=sotu_rabund,
    _rabund_ctrl=motu_rabund,
    _meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0),
)

### Individual strain turnover

#### g__Escherichia coli_D

In [None]:
# Load the fitted strain model for one species, rename samples to "CF_<int>",
# exchange the CF_11 / CF_15 labels, and run the turnover analysis on the
# strain composition.
_species_id = "102506"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
)

#### TODO: C. bolteae

In [None]:
# Load the fitted strain model for one species, rename samples to "CF_<int>",
# exchange the CF_11 / CF_15 labels, and run the turnover analysis on the
# strain composition.
_species_id = "101493"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
)

In [None]:
# Strain turnover for one species using UniFrac distances between samples
# instead of Bray-Curtis.
# NOTE(review): this cell follows the "C. bolteae" heading but loads species
# 102506, the ID used by the "g__Escherichia coli_D" cell — confirm which
# species is intended.
_species_id = "102506"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

# Align first so that the distance matrix can be subset in the same sample order.
_rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

_dmat = _world.unifrac_pdist(discretized=False).loc[_meta.index, _meta.index]

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
    _dmat=_dmat,
)

#### s__Eggerthella lenta

In [None]:
# Load the fitted strain model for one species, rename samples to "CF_<int>",
# exchange the CF_11 / CF_15 labels, and run the turnover analysis on the
# strain composition.
_species_id = "102544"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta = sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
)

In [None]:
# Strain turnover for one species using UniFrac distances between samples
# instead of Bray-Curtis.
_species_id = "102544"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta=sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

# Align first so that the distance matrix can be subset in the same sample order.
_rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

_dmat = _world.unifrac_pdist(discretized=False).loc[_meta.index, _meta.index]

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
    _dmat=_dmat,
)

#### s__Flavonifractor plautii

In [None]:
# Load the fitted strain model for one species, rename samples to "CF_<int>",
# exchange the CF_11 / CF_15 labels, and run the turnover analysis on the
# strain composition.
_species_id = "100099"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta = sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
)

In [None]:
# Strain turnover for s__Flavonifractor plautii using UniFrac distances
# between samples instead of Bray-Curtis.
# The species ID matches the Bray-Curtis cell for this species; this cell
# previously loaded 102544 (the s__Eggerthella lenta ID), an apparent
# copy-paste from that section.
_species_id = "100099"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta = sample[
    sample.diet_or_media.isin(["EEN", "PostEEN"])
    & ~sample.collection_date_relative_een_end.isna()
].assign(depth=0)

# Align first so that the distance matrix can be subset in the same sample order.
_rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

_dmat = _world.unifrac_pdist(discretized=False).loc[_meta.index, _meta.index]

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
    _dmat=_dmat,
)

#### s__Bacteroides uniformis

In [None]:
# Load the fitted strain model for one species, rename samples to "CF_<int>",
# exchange the CF_11 / CF_15 labels, and run the turnover analysis on the
# strain composition.
_species_id = "101346"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta = sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
)

#### TODO1

In [None]:
# Load the fitted strain model for one species, rename samples to "CF_<int>",
# exchange the CF_11 / CF_15 labels, and run the turnover analysis on the
# strain composition.
_species_id = "101378"
_world = (
    sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc",
        validate=False,
    )
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap
)
# Keep only samples whose mean metagenotype depth exceeds 0.1.
_rabund = _world.community.sel(sample=_world.metagenotype.mean_depth() > 1e-1).to_pandas()
_meta = sample[
        sample.diet_or_media.isin(["EEN", "PostEEN"])
        & ~sample.collection_date_relative_een_end.isna()
    ].assign(depth=0)

turnover_analysis3(
    _rabund=_rabund,
    _meta=_meta,
)