## Preamble

In [None]:
# Load IPython's autoreload extension so edited project modules can be re-imported
# without restarting the kernel.
# NOTE(review): no `%autoreload <mode>` call is visible in this chunk — confirm a
# reload mode is actually being selected somewhere.
%load_ext autoreload

In [None]:
import os as _os

# Run the notebook from the project root so that the relative paths used below
# ("data/...", "meta/...", "ref/...", "fig/...") resolve. Raises KeyError if the
# PROJECT_ROOT environment variable is not set.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Cell output: the resolved working directory, as a sanity check.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from fastcluster import linkage
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def _calculate_2tailed_pvalue_from_perm(obs, perms):
    hypoth_left = perms > obs
    hypoth_right = perms < obs
    null_p_left = (hypoth_left.sum() + 1) / (len(hypoth_left) + 1)
    null_p_right = (hypoth_right.sum() + 1) / (len(hypoth_right) + 1)
    return np.minimum(null_p_left, null_p_right) * 2

In [None]:
def linkage_order(linkage, labels):
    """Return `labels` reordered to follow the leaf order of a linkage tree.

    The linkage matrix is converted to a tree and its leaves are visited in
    pre-order; the resulting leaf ids are used to index into `labels`, which
    therefore must support indexing with a list of integer positions
    (e.g. a NumPy array).
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]


def is_prime(n):
    """Return True if `n` is a prime number, False otherwise.

    Uses trial division by every integer from 2 up to and including the
    integer part of sqrt(n). Values of 1 or less are never prime.
    """
    if n <= 1:
        return False
    stop = int(n**0.5) + 1
    return all(n % divisor != 0 for divisor in range(2, stop))


def iterate_primes_up_to(n, return_index=False):
    """Yield the primes strictly below ceil(n), in increasing order.

    When `return_index` is true, yield ``(index, prime)`` tuples instead,
    where `index` counts the primes found so far starting from 0.
    """
    upper = int(np.ceil(n))
    primes = (candidate for candidate in range(upper) if is_prime(candidate))
    if return_index:
        yield from enumerate(primes)
    else:
        yield from primes


def maximally_shuffled_order(sorted_order):
    """Return the items of `sorted_order` in a deterministic, scrambled order.

    Each item's original position is reduced modulo every prime below
    sqrt(n). Positions are then ranked lexicographically by those residues,
    and the item at original position ``i`` is assigned the ``i``-th entry of
    that ranking as its new position.

    NOTE(review): judging by the name, the intent is to keep items that are
    adjacent in the input apart in the output — confirm.

    Parameters
    ----------
    sorted_order : sequence
        Items in their original order. May be empty.

    Returns
    -------
    list
        The same items, reordered.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    for prime in primes_list:
        table[prime] = table.original_order % prime
    # With fewer than five items there is no prime below sqrt(n) to sort by;
    # the order is then left unchanged rather than sorting by an empty key list.
    by_residue = table.sort_values(primes_list) if primes_list else table
    table = table.assign(new_order=by_residue.original_order.values)
    return table.sort_values("new_order").index.to_list()

In [None]:
# Colors for each category of pairwise sample comparison.
pair_type_palette = dict(EEN="teal", PostEEN="mediumblue", Transition="blueviolet")

# Colors for each category of individual sample.
sample_type_palette = dict(
    EEN="lightgreen",
    PostEEN="lightblue",
    InVitro="plum",
    PreEEN="lightpink",
)

# Display order of the subjects (note that "H" sits between "B" and "C").
subject_order = list("ABHCDEFGKLMNOPQRSTU")

# NOTE: The subject list is padded with dummy keys so that the palette is built
# from exactly 20 items.
subject_palette = lib.plot.construct_ordered_palette(
    [*subject_order, *(f"dummy{i}" for i in range(20 - len(subject_order)))],
    cm="tab20",
)

# Display order, marker, and line style for each pair type.
pair_type_order = ["EEN", "Transition", "PostEEN"]
pair_type_marker_palette = dict(EEN="s", Transition=">", PostEEN="o")
pair_type_linestyle_palette = dict(EEN=":", Transition="-.", PostEEN="-")

In [None]:
# Load per-sample, per-species depths (long format) and pivot them into a wide
# sample-by-species table; sample/species pairs absent from the file get depth 0.
# Column labels (species IDs) are converted to strings.
species_depth = (
    pd.read_table(
        "data/group/een/r.proc.gtpro.species_depth.tsv",
        index_col=["sample", "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str)
)

# NOTE(review): `bins` is not used by the histograms below (they build their own
# log-spaced bins) and is reassigned in a later cell — possibly a leftover.
bins = np.linspace(0, 30_000, num=200)

fig, axs = plt.subplots(2, sharex=True)

# One log-scaled histogram per margin of the depth table: total depth per sample
# (row sums) and total depth per species (column sums).
for (title, x), ax in zip(
    dict(
        total_depth_by_sample=species_depth.sum(1),
        total_depth_by_species=species_depth.sum(0),
    ).items(),
    axs.flatten(),
):
    ax.hist(x, bins=np.logspace(-1, 5, num=100))
    ax.set_title(title)
    ax.set_xscale("log")
fig.tight_layout()

In [None]:
# Normalize each sample (row) by its total depth to get relative abundances.
_sample_total_depth = species_depth.sum(axis=1)
species_rabund = species_depth.div(_sample_total_depth, axis=0)
# Cell output: the 20 species with the highest mean relative abundance.
species_rabund.mean().sort_values(ascending=False).iloc[:20]

In [None]:
# Per-sample relative abundance: each row of the depth table divided by its sum.
# (Same computation as the previous cell, without the summary output.)
species_rabund = species_depth.div(species_depth.sum(axis=1), axis=0)

In [None]:
# Rank species by the number of samples in which their relative abundance
# exceeds 1e-5 and keep the top `n_species`.
n_species = 10
top_species = (
    (species_rabund > 1e-5).sum().sort_values(ascending=False).head(n_species).index
)

# One thin, stacked panel per species, all sharing both axes.
fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

# Log-spaced histogram bins spanning 1e-8 to 10.
bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_species, axs):
    # ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(species_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    # Fraction of samples in which this species exceeds the 1e-5 threshold.
    prevalence = (species_rabund[species_id] > 1e-5).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    # Strip all axis decoration and make the panel background transparent.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    # Label the panel with the species ID and its prevalence as a percentage.
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    # Dotted vertical line at the 1e-5 prevalence threshold.
    ax.axvline(1e-5, lw=1, linestyle=":", color="k")

# `ax` is still the last panel drawn by the loop above; only that panel gets a
# visible x-axis, with ticks labeled as percentages.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Load the species taxonomy table via the project helper.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    "ref/gtpro/species_taxonomy_ext.tsv"
)
# Manually override the species-level name recorded for species 102506.
species_taxonomy.loc["102506", "s__"] = "s__Escherichia coli"

In [None]:
# Print the full taxonomy string of each currently selected top species.
# NOTE: the loop variable `_species_id` remains defined at module level afterwards.
for _species_id in top_species.astype(str):
    print(_species_id, ":", species_taxonomy.taxonomy_string.loc[_species_id])

In [None]:
# Cell output: taxonomy rows whose species-level name ends with "hansenii".
species_taxonomy[lambda x: x.s__.str.endswith("hansenii")]

In [None]:
# For each listed species, print its ID, the fraction of samples in which its
# relative abundance exceeds 0.0001 and 0.001 (each rounded to 2 decimals),
# and its species-level name.
for _species_id in ["102544", "102506", "101303", "100150", "102330", "101704"]:
    print(
        _species_id,
        (species_rabund[_species_id] > 0.0001).mean().round(2),
        (species_rabund[_species_id] > 0.001).mean().round(2),
        species_taxonomy.loc[_species_id].s__,
        sep="\t\t",
    )

In [None]:
n_species = 20
# Relative-abundance cutoff that defines "present". It is used consistently for
# ranking species, for the per-panel prevalence annotation, and for the dotted
# reference line (which was previously still drawn at 1e-5).
prevalence_thresh = 1e-3
top_species = (
    (species_rabund > prevalence_thresh)
    .sum()
    .sort_values(ascending=False)
    .head(n_species)
    .index
)

# One thin, stacked panel per species, all sharing both axes.
fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

# Log-spaced histogram bins spanning 1e-8 to 10.
bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_species, axs):
    ax.hist(species_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    # Fraction of samples in which this species exceeds the cutoff.
    prevalence = (species_rabund[species_id] > prevalence_thresh).mean()
    ax.set_title("")
    # Strip all axis decoration and make the panel background transparent.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    # Label the panel with the species ID and its prevalence as a percentage.
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    # Dotted vertical line marking the prevalence cutoff.
    ax.axvline(prevalence_thresh, lw=1, linestyle=":", color="k")

# `ax` is still the last panel drawn by the loop above; only that panel gets a
# visible x-axis, with ticks labeled as percentages.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Build a sample-by-strain depth table by splitting each species' depth among
# its inferred strains.
strain_depth = []
missing_files = []
for species_id in species_depth.columns:
    path = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        # Long (sample, strain) table -> wide sample-by-strain table.
        d = pd.read_table(path, index_col=["sample", "strain"]).squeeze().unstack()
    except FileNotFoundError:
        # No fit for this species: record the path and continue with an empty
        # table, so the species ends up represented only by the "other" column.
        missing_files.append(path)
        d = pd.DataFrame([])
    # Keep strains whose values summed across samples exceed 0.05.
    _keep_strains = idxwhere(d.sum() > 0.05)
    # Align to the full sample list; samples/strains not in the table get 0.
    d = d.reindex(index=species_depth.index, columns=_keep_strains, fill_value=0)
    # Add a catch-all "other" column (labeled -1) holding whatever fraction is
    # not accounted for by the kept strains.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    # Clip negative remainders to 0, then renormalize each sample to sum to 1.
    d[d < 0] = 0
    d = d.divide(d.sum(1), axis=0)
    # Scale the per-sample fractions by the species' depth in that sample.
    d = d.multiply(species_depth[species_id], axis=0)
    # Column labels become "<species_id>_<strain>".
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    strain_depth.append(d)
strain_depth = pd.concat(strain_depth, axis=1)
# Per-sample relative abundance of every strain across all species.
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)
# Cell output: number of species and number of species lacking a fit file.
len(species_depth.columns), len(missing_files)

In [None]:
# Log-log histogram of all strain relative abundances; the 1e-10 pseudocount
# places exact zeros at the lower edge of the log-spaced bins.
plt.hist(strain_rabund.values.flatten() + 1e-10, bins=np.logspace(-10, 0))
plt.yscale("log")
plt.xscale("log")

In [None]:
# Print the full taxonomy string of each currently selected top species
# (`top_species` was last reassigned by the 20-species histogram cell above).
for _species_id in top_species.astype(str):
    print(_species_id, ":", species_taxonomy.taxonomy_string.loc[_species_id])

In [None]:
# Load the stool-sample metadata, indexed by metagenome ID.
meta = (
    pd.read_table("meta/een-mgen/stool.tsv")
    # .rename(columns={'Seq-Name': 'sample', 'CED/Patient-recoded': 'subject_id', 'sampleDate': 'date', 'Diet (=PreEEN, EEN, PostEEN)': 'sample_type'})
    # .assign(
    #     date=lambda x: pd.to_datetime(x.date),
    #     sample_type=lambda x: x.sample_type.fillna('???')
    # )
    .set_index("mgen_id")
    # .sort_values(['subject_id', 'date', 'sample_type'])
)

# FIXME: The metadata appears to have the metagenomes of CF_11 and CF_15 swapped.
# As a workaround, exchange the two index labels here (see the
# "Confirm Sample Swap" section below).
meta = meta.rename({"CF_11": "CF_15", "CF_15": "CF_11"})
meta

In [None]:
# Load the in-vitro (microcosm) sample metadata, indexed by metagenome ID. The
# inoculum's subject is renamed to `subject_id` to match the stool metadata.
invitro_meta = (
    pd.read_table("meta/een-mgen/microcosm.tsv")
    .set_index("mgen_id")
    .rename(columns={"inoculum_subject_id": "subject_id"})
)
invitro_meta

In [None]:
# Combine stool and in-vitro metadata into one table and give every sample a
# tuple-valued `label` column.
meta_all = pd.concat(
    [
        # Stool samples: label = (mgen_id, relative collection date, sample type).
        meta.assign(
            label=lambda x: x.assign(idx=x.index)[
                ["idx", "collection_date_relative_een_end", "sample_type"]
            ].apply(tuple, axis=1)
        ),
        # In-vitro samples: label = (mgen_id, "invitro[<inoculum mgen_id>]").
        # The collection date is set to +inf — presumably so these samples sort
        # after all stool samples; confirm.
        invitro_meta.assign(
            collection_date_relative_een_end=np.inf,
            sample_type="InVitro",
            invitro_info=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
            label=lambda x: x.assign(idx=x.index)[["idx", "invitro_info"]].apply(
                tuple, axis=1
            ),
        ),
    ]
)

In [None]:
# Sanity check: every stool sample in the metadata has species abundance data.
# NOTE(review): only `meta` is checked, not the in-vitro samples in `meta_all`.
assert meta.index.isin(species_rabund.index).all()

In [None]:
# Cell output: number of samples of each type per subject (0 where a
# combination is absent), sorted by the EEN count, first 20 subjects.
meta_all[["subject_id", "sample_type"]].value_counts().unstack(
    fill_value=0
).sort_values("EEN", ascending=False).head(20)

### Confirm Sample Swap

In [None]:
# Heatmap of SPECIES relative abundance, keeping species that exceed 1e-5 in more
# than one sample.
d = species_rabund.loc[:, lambda x: (x > 1e-5).sum() > 1]
# Samples are clustered by Bray-Curtis dissimilarity of their STRAIN-level
# profiles (average linkage), even though species-level data is what is shown.
sample_linkage_strain_rabund = linkage(
    strain_rabund, method="average", metric="braycurtis"
)
# Per-sample annotation colors: subject, sample type, and a black flag on the
# two samples involved in the suspected swap (CF_11 / CF_15).
row_colors = pd.DataFrame(
    dict(
        subj=meta_all.subject_id.map(subject_palette),
        type=meta_all.sample_type.map(sample_type_palette),
        swap=meta_all.index.to_series()
        .isin(["CF_11", "CF_15"])
        .map({True: "black", False: "grey"}),
    )
)

# The table is transposed so that samples are columns; `row_colors` therefore
# annotates the columns, and the column dendrogram uses the precomputed linkage.
cg = sns.clustermap(
    d.T,
    norm=mpl.colors.PowerNorm(1 / 5),
    col_linkage=sample_linkage_strain_rabund,
    metric="cosine",
    xticklabels=1,
    figsize=(15, 10),
    col_colors=row_colors,
    dendrogram_ratio=(0.05, 0.05),
    yticklabels=0,
)
cg.ax_cbar.set_visible(False)

In [None]:
# Heatmap of STRAIN relative abundance, keeping strains that exceed 1e-5 in more
# than one sample.
d = strain_rabund.loc[:, lambda x: (x > 1e-5).sum() > 1]
# Samples are clustered by Bray-Curtis dissimilarity of the full (unfiltered)
# strain-level profiles, using average linkage.
sample_linkage_strain_rabund = linkage(
    strain_rabund, method="average", metric="braycurtis"
)
# Per-sample annotation colors: subject, sample type, and a black flag on the
# two samples involved in the suspected swap (CF_11 / CF_15).
row_colors = pd.DataFrame(
    dict(
        subj=meta_all.subject_id.map(subject_palette),
        type=meta_all.sample_type.map(sample_type_palette),
        swap=meta_all.index.to_series()
        .isin(["CF_11", "CF_15"])
        .map({True: "black", False: "grey"}),
    )
)

# The table is transposed so that samples are columns; `row_colors` therefore
# annotates the columns, and the column dendrogram uses the precomputed linkage.
cg = sns.clustermap(
    d.T,
    norm=mpl.colors.PowerNorm(1 / 5),
    col_linkage=sample_linkage_strain_rabund,
    metric="cosine",
    xticklabels=1,
    figsize=(15, 10),
    col_colors=row_colors,
    dendrogram_ratio=(0.05, 0.05),
    yticklabels=0,
)
cg.ax_cbar.set_visible(False)

## Load Species Metagenotype / Strain Deconvolution

## Focal Species Plots

In [None]:
def strains_in_subjects(
    _species_rabund,
    _species_id,
    _world,
    _meta,
    _species_taxonomy=species_taxonomy,
    savefig=False,
    _subject_order=subject_order,
    _outpath_pattern="fig/een_{_species_name}_strain_tracking.pdf",
):
    """Plot, for each subject, the strain composition of one species over samples.

    Each subject gets one panel: a stacked bar chart of strain fractions (with
    the catch-all "other" strain, labeled -1, drawn first) and, on a twin
    y-axis, a line showing the species' relative abundance.

    Parameters
    ----------
    _species_rabund : pandas.Series or None
        Relative abundance of the species, indexed by sample. If None, falls
        back to the notebook-level ``species_rabund[_species_id]``.
    _species_id : str
        Species identifier; used for the taxonomy lookup and the fallback paths.
    _world : sfacts World or None
        Fitted strain model for the species. If None, the default sfacts37 fit
        for ``_species_id`` is loaded from disk. (Earlier versions of this
        function ignored this argument and always loaded that default fit.)
    _meta : pandas.DataFrame
        Sample metadata; must provide the columns ``subject_id``,
        ``collection_date_relative_een_end``, ``sample_type`` and ``label``.
    _species_taxonomy : pandas.DataFrame
        Taxonomy table indexed by species ID with an ``s__`` column.
    savefig : bool
        If true, save the figure to the path built from ``_outpath_pattern``.
    _subject_order : list
        Subjects to plot, in panel order.
    _outpath_pattern : str
        Output path template; ``{_species_name}`` is substituted.

    Returns
    -------
    None
    """
    print(_species_taxonomy.loc[_species_id])
    _species_name = (
        _species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
    )

    # Honor the caller's inputs; only fall back to the notebook defaults when
    # an argument is explicitly None.
    if _species_rabund is None:
        _species_rabund = species_rabund[_species_id]
    if _world is None:
        _world = sf.data.World.load(
            f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
        )

    # Drop low-abundance strains once and reuse the result for both the strain
    # fractions and the palette ordering.
    _world_drop_low_abund = _world.drop_low_abundance_strains(0.05)
    _frac = _world_drop_low_abund.community.to_pandas()

    # Order the palette by genotype similarity so that similar strains get
    # similar colors. (Alternatives tried previously: an arbitrarily ordered
    # palette, and an order shuffled with `maximally_shuffled_order` so that
    # co-occurring strains get dissimilar colors.)
    _strain_list = list(
        linkage_order(
            _world_drop_low_abund.genotype.linkage(optimal_ordering=True),
            _world_drop_low_abund.strain.values,
        )
    )
    # Drop the "other" strain; assumes it is always present after
    # `drop_low_abundance_strains` — raises ValueError otherwise.
    _strain_list.remove(-1)
    strain_palette = lib.plot.construct_ordered_palette(
        _strain_list,
        cm="rainbow",
    )

    # Metadata + species abundance + strain fractions, ordered within subject.
    d0 = (
        _meta.assign(species_rabund=_species_rabund)
        .join(_frac)
        .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
    )

    fig, axs = lib.plot.subplots_grid(
        ncols=4,
        naxes=len(_subject_order),
        ax_width=5,
        ax_height=4,
    )
    fig.suptitle(_species_name)
    for subject_id, ax in zip(_subject_order, axs.flatten()):
        ax.set_title(subject_id)
        twin_ax = ax.twinx()
        d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
        if d1.empty:
            continue
        d1[[-1] + _strain_list].plot(
            kind="bar",
            width=0.95,
            stacked=True,
            color=strain_palette,
            ax=ax,
        )
        d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
        ax.legend_.set_visible(False)
        ax.set_ylim(0, 1)
        ax.set_ylabel("strain fraction")
        ax.set_xlabel("")
        twin_ax.set_ylabel("species relative abundance")
        twin_ax.set_ylim(0)
        lib.plot.rotate_xticklabels(ax)
        # Fixed x-range so that bar widths are comparable across panels.
        ax.set_xlim(-0.5, 14)
    fig.tight_layout()

    if savefig:
        _outpath = _outpath_pattern.format(_species_name=_species_name)
        print(_outpath)
        fig.savefig(_outpath)

In [None]:
def compete_two_fits(_world0, _world1, plot_npos=1000, low_abund_thresh=0.05):
    """Draw diagnostic plots comparing two strain-deconvolution fits.

    Produces a metagenotype plot and a community plot for `_world0`, a
    community plot for `_world1` on the same positions and samples, and a
    2x2 figure with: per-sample community entropy of the two fits against
    each other; UniFrac distance against metagenotype distance for both fits;
    and the per-sample Spearman correlation between UniFrac and metagenotype
    distances for the two fits against each other.

    Parameters
    ----------
    _world0, _world1 : sfacts World
        The two fits to compare.
    plot_npos : int
        Maximum number of positions randomly sampled for the plots.
    low_abund_thresh : float
        Threshold passed to `drop_low_abundance_strains` for BOTH fits
        (the second fit previously used a hard-coded 0.05).

    Returns
    -------
    None
    """
    w0 = _world0.random_sample(
        position=min(_world0.sizes["position"], plot_npos)
    ).drop_low_abundance_strains(low_abund_thresh)
    w1 = _world1.sel(position=w0.position, sample=w0.sample).drop_low_abundance_strains(
        low_abund_thresh
    )

    # All three plots order their columns by the metagenotype linkage of the
    # first fit, so that samples line up across plots.
    sf.plot.plot_metagenotype(
        w0, col_linkage_func=lambda w: w0.metagenotype.linkage(optimal_ordering=True)
    )
    sf.plot.plot_community(
        w0,
        col_linkage_func=lambda w: w0.metagenotype.linkage(optimal_ordering=True),
        row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
    )
    sf.plot.plot_community(
        w1,
        col_linkage_func=lambda w: w0.metagenotype.linkage(optimal_ordering=True),
        row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
    )

    # Pairwise distance matrices on the full (unsubsampled) worlds, computed
    # once and reused below.
    _mgtp_pdist = _world0.metagenotype.pdist()
    _unifrac_pdist0 = _world0.unifrac_pdist()
    _unifrac_pdist1 = _world1.unifrac_pdist()

    fig, axs = plt.subplots(2, 2)
    ax = axs[0, 0]
    ax.scatter(w0.community.entropy(), w1.community.entropy())
    ax.plot([0, 2.5], [0, 2.5])

    ax = axs[0, 1]
    ax.scatter(squareform(_unifrac_pdist0), squareform(_mgtp_pdist))
    ax.scatter(
        squareform(_unifrac_pdist1), squareform(_world1.metagenotype.pdist())
    )

    # Per-sample agreement between each fit's UniFrac distances and the
    # metagenotype distances (taken from `_world0` for both fits).
    _world0_sample_corr = {}
    _world1_sample_corr = {}
    for sample in _world0.sample.values:
        _world0_sample_corr[sample] = sp.stats.spearmanr(
            _unifrac_pdist0.loc[sample], _mgtp_pdist.loc[sample]
        )[0]
        _world1_sample_corr[sample] = sp.stats.spearmanr(
            _unifrac_pdist1.loc[sample], _mgtp_pdist.loc[sample]
        )[0]

    sample_accuracy = pd.DataFrame(
        dict(world0=_world0_sample_corr, world1=_world1_sample_corr)
    )

    ax = axs[1, 0]
    ax.scatter("world0", "world1", data=sample_accuracy)
    ax.plot([0, 1], [0, 1])

### g__Escherichia

In [None]:
_species_id = "102506"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### s__Eggerthella lenta

In [None]:
_species_id = "102544"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### g__Dorea scindens

In [None]:
_species_id = "101303"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### s__Hungatella hathewayi

In [None]:
_species_id = "100150"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### s__Eisenbergiella tayi

In [None]:
_species_id = "102330"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### s__Blautia hansenii

In [None]:
_species_id = "101704"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

## Other Species

### s__Flavonifractor plautii

In [None]:
_species_id = "100099"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### s__Ruminococcus_B gnavus

In [None]:
_species_id = "101380"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata.
# NOTE(review): this cell loads the sfacts50 fit, whereas every other species
# cell in this notebook loads the sfacts37 fit — confirm which fit this figure
# is meant to show.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### s__Bacteroides uniformis

In [None]:
_species_id = "101346"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

### s__Bacteroides_B dorei

In [None]:
_species_id = "102478"

# Per-subject strain-tracking figure for this species, using the combined
# stool + in-vitro metadata and the sfacts37 fit loaded here.
strains_in_subjects(
    _species_rabund=species_rabund[_species_id],
    _species_id=_species_id,
    _species_taxonomy=species_taxonomy,
    _world=sf.data.World.load(
        f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    ),
    _meta=meta_all,
)

# compete_two_fits(
#     _world0 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts50-s85-seed0.world.nc"
#     ),
#     _world1 = sf.data.World.load(
#         f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
#     ),
# )

## Permutation Test Prototype #3

This permutation test differs from earlier versions in that:
- it permutes the pair labels,
- it fits a cubic spline (with 4 knots) to the time data, and
- it is run on the compositional data (instead of metagenomics).

In [None]:
def turnover_analysis3(
    _rabund,
    _meta,
    pair_type_palette=pair_type_palette,
    pair_type_marker_palette=pair_type_marker_palette,
    pair_type_linestyle_palette=pair_type_linestyle_palette,
    subject_palette=subject_palette,
    subject_order=subject_order,
    pair_type_order=pair_type_order,
    n_perm=999,
    seed=1,
):
    """Model within-subject Bray-Curtis turnover and permutation-test pair types.

    Every pair of samples is scored by its Bray-Curtis dissimilarity on
    `_rabund`, by the absolute difference in
    `collection_date_relative_een_end`, and by a pair type ("EEN",
    "Transition" or "PostEEN", according to how many of the two samples have
    `sample_type == "PostEEN"`).  Within-subject pairs are fit by OLS with
    `bc ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)`, and the
    three pairwise differences between pair-type coefficients are compared to
    refits with shuffled pair-type labels.  Several matplotlib/seaborn figures
    are created as a side effect, and the observed contrasts are printed.

    Parameters
    ----------
    _rabund
        Table with one row per sample; it is aligned with `_meta` and its rows
        are passed to `pdist` (presumably relative abundances -- confirm
        against callers).
    _meta
        Sample metadata with `subject_id`, `sample_type` and
        `collection_date_relative_een_end` columns.
    pair_type_palette, pair_type_marker_palette, pair_type_linestyle_palette
        Mappings from pair type to color, marker and line style.
    subject_palette
        Mapping from subject id to color.
    subject_order, pair_type_order
        Order in which subjects / pair types are iterated and plotted.
    n_perm
        Number of label permutations (default 999, the previously hard-coded
        value).
    seed
        Value passed to `np.random.seed` before permuting (default 1).  Note
        that this resets NumPy's global random state.

    Returns
    -------
    The statsmodels summary of the fit to the observed data.
    """
    contrast_names = ["Transition - EEN", "Transition - PostEEN", "PostEEN - EEN"]

    def _add_contrasts(params):
        # Work on a copy so the fitted results object is never touched.
        params = params.copy()
        een = params["C(pair_type)[EEN]"]
        post_een = params["C(pair_type)[PostEEN]"]
        transition = params["C(pair_type)[Transition]"]
        params["PostEEN - EEN"] = post_een - een
        params["Transition - EEN"] = transition - een
        params["Transition - PostEEN"] = transition - post_een
        return params

    # Select data
    _rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

    # Calculate pairwise comparisons (all in scipy's condensed pdist order).
    bc_cdist = pdist(_rabund, metric="braycurtis")
    time_cdist = pdist(
        _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
    )
    diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
    same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
    # 0 / 1 / 2 = number of PostEEN samples in the pair.
    type_transition_indicator = pdist(
        _meta[["sample_type"]],
        metric=lambda x, y: (x == "PostEEN").astype(float)
        + (y == "PostEEN").astype(float),
    )
    # Build the (sampleA, sampleB) label for every pair; taking the strict
    # upper triangle yields the same ordering as the condensed vectors above.
    pairs = (
        pd.DataFrame(
            squareform(diff_subject_cdist),
            index=_meta.index,
            columns=_meta.index,
        )
        .rename_axis(index="sampleA", columns="sampleB")
        .unstack()
        .index.to_frame()
        .assign(pair=lambda x: x[["sampleA", "sampleB"]].apply(tuple, axis=1))
        .pair.unstack()
    )
    pairs = pd.Series(pairs.values[np.triu_indices_from(pairs.values, k=1)])
    d0 = pd.DataFrame(
        dict(
            sampleA=pairs.str[0],
            sampleB=pairs.str[1],
            bc=bc_cdist,
            same_subject=same_subject_cdist,
            type_transition_indicator=type_transition_indicator,
            diff_type=type_transition_indicator == 1,
            time_delta=time_cdist,
        )
    ).assign(
        pair_type=lambda x: x.type_transition_indicator.map(
            {0: "EEN", 1: "Transition", 2: "PostEEN"}
        ),
        subject_id=lambda x: x.sampleA.map(_meta.subject_id),
        subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    )
    # Only within-subject pairs enter the model.
    d1 = d0[lambda x: x.same_subject]

    # Plot the observed relationship
    fig, axs = plt.subplots(1, 2)
    ax = axs[0]
    sns.boxenplot(
        y="bc",
        x="same_subject",
        hue="pair_type",
        data=d0,
        ax=ax,
        palette=pair_type_palette,
        order=[False, True],
        hue_order=pair_type_order,
    )
    ax = axs[1]
    sns.stripplot(
        y="bc",
        x="pair_type",
        hue="subject_id",
        data=d1,
        ax=ax,
        palette=subject_palette,
        order=pair_type_order,
    )
    ax.legend_.set_visible(False)

    # Fit observed relationship
    formula = "bc ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)"
    fit = smf.ols(formula, data=d1).fit()
    observed_stat = _add_contrasts(fit.params)
    print(observed_stat[contrast_names])

    # Calculate permutations
    np.random.seed(seed)
    perm_stat = {}
    for i in tqdm(range(n_perm)):
        # Permute sample-pair labels within subjects.
        # NOTE(review): labels are shuffled over *all* pairs in d0 (grouped by
        # sampleA's subject), so between-subject pairs' labels can land on the
        # within-subject rows that are modelled -- confirm this is intended
        # rather than shuffling within d1 only.
        _perm_type = d0.groupby("subject_id").pair_type.transform(
            np.random.permutation
        )
        perm_fit = smf.ols(
            formula,
            data=d0.assign(pair_type=_perm_type)[lambda x: x.same_subject],
        ).fit()
        perm_stat[i] = _add_contrasts(perm_fit.params)
    perm_stat = pd.DataFrame(perm_stat).T

    perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

    # Plot permutation tests: one null histogram per contrast, with the
    # observed value as a vertical line and the p-value in the corner.
    fig, axs = plt.subplots(3)
    for ax, param_name in zip(axs, contrast_names):
        ax.set_title(param_name)
        ax.hist(perm_stat[param_name], bins=20)
        ax.axvline(observed_stat[param_name])
        ax.annotate(
            perm_pvalues[param_name],
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
        )
    fig.tight_layout()

    # Residual diagnostics, colored by Cook's distance.
    d2 = d1.assign(
        resid_pearson=fit.resid_pearson,
        influence=fit.get_influence().summary_frame().cooks_d,
    ).sort_values("time_delta")

    fig, ax = plt.subplots()
    art = ax.scatter(
        "time_delta",
        "resid_pearson",
        c="influence",
        data=d2,
    )
    fig.colorbar(art, label="Cook's D")
    ax.set_xlabel("Within-Subjects Days between Samples")
    ax.set_xscale("symlog")

    # Predictions are made for one subject and then shifted by that subject's
    # coefficient, so the curves do not include that subject's effect.
    # NOTE(review): this needs at least two subjects, and the coefficient
    # lookup raises KeyError if the chosen subject has no explicit
    # `C(subject_id, Sum)[S.<id>]` term in the fit -- confirm.
    _arbitrary_subject = d1.subject_id.unique()[1]
    _subject_effect = fit.params[f"C(subject_id, Sum)[S.{_arbitrary_subject}]"]

    def _prediction_grid(log10_start):
        grid = pd.DataFrame(
            product(
                [_arbitrary_subject],
                ["EEN", "PostEEN", "Transition"],
                np.logspace(log10_start, 2.6),
            ),
            columns=["subject_id", "pair_type", "time_delta"],
        )
        return grid.assign(
            prediction=fit.predict(grid),
            predict_mean_subject=lambda x: x.prediction - _subject_effect,
        )

    def _plot_trends(ax, grid, color_of):
        # Draw each pair type's curve only between the 5% and 95% quantiles
        # of that pair type's observed time deltas.
        for pair_type in pair_type_order:
            d4 = grid[lambda x: x.pair_type == pair_type]
            left, right = d1[lambda x: x.pair_type == pair_type].time_delta.quantile(
                [0.05, 0.95]
            )
            ax.plot(
                "time_delta",
                "predict_mean_subject",
                label="__nolegend__",
                data=d4[lambda x: (x.time_delta > left) & (x.time_delta < right)],
                color=color_of(pair_type),
                linestyle=pair_type_linestyle_palette[pair_type],
            )

    # Turnover vs. time, colored by pair type.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for pair_type in pair_type_order:
        d3 = d1[lambda x: (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=pair_type_palette[pair_type],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
        )
    _plot_trends(ax, _prediction_grid(1.0), lambda pt: pair_type_palette[pt])
    ax.set_ylabel("Bray-Curtis Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)
    # Empty artists that exist only to populate the legend.
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color=pair_type_palette[pair_type],
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()

    # Turnover vs. time, colored by subject (marker encodes pair type).
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for subject_id, pair_type in product(subject_order, pair_type_order):
        d3 = d2[lambda x: (x.subject_id == subject_id) & (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=subject_palette[subject_id],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
            alpha=0.7,
        )
    _plot_trends(ax, _prediction_grid(1.1), lambda pt: "black")
    ax.set_ylabel("Bray-Curtis Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)

    # Stand-alone legend figures.
    fig, ax = plt.subplots(figsize=(5, 2))
    for subject_id in subject_order:
        ax.scatter(
            [], [], label=subject_id, color=subject_palette[subject_id], marker="s"
        )
    ax.legend(ncols=5, title="Subject")

    fig, ax = plt.subplots(figsize=(2.5, 1.5))
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color="black",
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()
    return fit.summary()

### Species Turnover

In [None]:
# Turnover on the species table, restricted to EEN and PostEEN samples.
turnover_analysis3(
    species_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

### Strain Turnover

In [None]:
# Turnover on the strain table, restricted to EEN and PostEEN samples.
turnover_analysis3(
    strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

### Within Species Turnover Analysis

#### E. coli

In [None]:
# E. coli: load the fitted world for this species.
_species_id = "102506"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### E. lenta

In [None]:
# E. lenta: load the fitted world for this species.
_species_id = "102544"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### C. scindens

In [None]:
# C. scindens: load the fitted world for this species.
_species_id = "101303"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### H. hathewayi

In [None]:
# H. hathewayi: load the fitted world for this species.
_species_id = "100150"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### F. plautii

In [None]:
# F. plautii: load the fitted world for this species.
_species_id = "100099"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### R. gnavus

In [None]:
# R. gnavus: load the fitted world for this species.
_species_id = "101380"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### B. uniformis

In [None]:
# B. uniformis: load the fitted world for this species.
_species_id = "101346"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### B. dorei

In [None]:
# B. dorei: load the fitted world for this species.
_species_id = "102478"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

# Keep samples with mean metagenotype depth above 0.5, then take the
# community table as a pandas object.
single_species_strain_rabund = _world.sel(
    sample=(_world.metagenotype.mean_depth() > 0.5)
)
single_species_strain_rabund = single_species_strain_rabund.community.to_pandas()

turnover_analysis3(
    single_species_strain_rabund,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

### Unifrac instead of BC

In [None]:
def turnover_analysis6(
    _world,
    _meta,
    pair_type_palette=pair_type_palette,
    pair_type_marker_palette=pair_type_marker_palette,
    pair_type_linestyle_palette=pair_type_linestyle_palette,
    subject_palette=subject_palette,
    subject_order=subject_order,
    pair_type_order=pair_type_order,
    n_perm=999,
    seed=1,
):
    """Model within-subject UniFrac turnover and permutation-test pair types.

    Samples listed in `_meta` whose mean metagenotype depth in `_world`
    exceeds 0.5 are kept.  Every pair of kept samples is scored by
    `_world.unifrac_pdist()`, by the absolute difference in
    `collection_date_relative_een_end`, and by a pair type ("EEN",
    "Transition" or "PostEEN", according to how many of the two samples have
    `sample_type == "PostEEN"`).  Within-subject pairs are fit by OLS with
    `diss ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)`, and
    the three pairwise differences between pair-type coefficients are compared
    to refits with shuffled pair-type labels.  Several matplotlib/seaborn
    figures are created as a side effect, and the observed contrasts are
    printed.

    Parameters
    ----------
    _world
        Object exposing `metagenotype.mean_depth()`, `sel(sample=...)` and
        `unifrac_pdist()` (callers pass a loaded `sf.data.World`).
    _meta
        Sample metadata with `subject_id`, `sample_type` and
        `collection_date_relative_een_end` columns.
    pair_type_palette, pair_type_marker_palette, pair_type_linestyle_palette
        Mappings from pair type to color, marker and line style.
    subject_palette
        Mapping from subject id to color.
    subject_order, pair_type_order
        Order in which subjects / pair types are iterated and plotted.
    n_perm
        Number of label permutations (default 999, the previously hard-coded
        value).
    seed
        Value passed to `np.random.seed` before permuting (default 1).  Note
        that this resets NumPy's global random state.

    Returns
    -------
    The statsmodels summary of the fit to the observed data.
    """
    contrast_names = ["Transition - EEN", "Transition - PostEEN", "PostEEN - EEN"]

    def _add_contrasts(params):
        # Work on a copy so the fitted results object is never touched.
        params = params.copy()
        een = params["C(pair_type)[EEN]"]
        post_een = params["C(pair_type)[PostEEN]"]
        transition = params["C(pair_type)[Transition]"]
        params["PostEEN - EEN"] = post_een - een
        params["Transition - EEN"] = transition - een
        params["Transition - PostEEN"] = transition - post_een
        return params

    # Select data: samples in the metadata that also pass the depth filter.
    # The list follows `_meta`'s own order (instead of set iteration order,
    # which varies between interpreter sessions) so that the seeded
    # permutations below are reproducible.
    _deep_samples = set(idxwhere(_world.metagenotype.mean_depth().to_series() > 0.5))
    _sample_list = [s for s in _meta.index.unique() if s in _deep_samples]
    # Index the `_meta` argument, not the notebook-global `meta`.
    _meta = _meta.loc[_sample_list]
    _world = _world.sel(sample=_sample_list)

    # Calculate pairwise comparisons (all in scipy's condensed pdist order).
    unif_cdist = squareform(_world.unifrac_pdist())
    time_cdist = pdist(
        _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
    )
    diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
    same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
    # 0 / 1 / 2 = number of PostEEN samples in the pair.
    type_transition_indicator = pdist(
        _meta[["sample_type"]],
        metric=lambda x, y: (x == "PostEEN").astype(float)
        + (y == "PostEEN").astype(float),
    )
    # Build the (sampleA, sampleB) label for every pair; taking the strict
    # upper triangle yields the same ordering as the condensed vectors above.
    pairs = (
        pd.DataFrame(
            squareform(diff_subject_cdist),
            index=_meta.index,
            columns=_meta.index,
        )
        .rename_axis(index="sampleA", columns="sampleB")
        .unstack()
        .index.to_frame()
        .assign(pair=lambda x: x[["sampleA", "sampleB"]].apply(tuple, axis=1))
        .pair.unstack()
    )
    pairs = pd.Series(pairs.values[np.triu_indices_from(pairs.values, k=1)])
    d0 = pd.DataFrame(
        dict(
            sampleA=pairs.str[0],
            sampleB=pairs.str[1],
            diss=unif_cdist,
            same_subject=same_subject_cdist,
            type_transition_indicator=type_transition_indicator,
            diff_type=type_transition_indicator == 1,
            time_delta=time_cdist,
        )
    ).assign(
        pair_type=lambda x: x.type_transition_indicator.map(
            {0: "EEN", 1: "Transition", 2: "PostEEN"}
        ),
        subject_id=lambda x: x.sampleA.map(_meta.subject_id),
        subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    )
    # Only within-subject pairs enter the model.
    d1 = d0[lambda x: x.same_subject]

    # Plot the observed relationship
    fig, axs = plt.subplots(1, 2)
    ax = axs[0]
    sns.boxenplot(
        y="diss",
        x="same_subject",
        hue="pair_type",
        data=d0,
        ax=ax,
        palette=pair_type_palette,
        order=[False, True],
        hue_order=pair_type_order,
    )
    ax = axs[1]
    sns.stripplot(
        y="diss",
        x="pair_type",
        hue="subject_id",
        data=d1,
        ax=ax,
        palette=subject_palette,
        order=pair_type_order,
    )
    ax.legend_.set_visible(False)

    # Fit observed relationship
    formula = "diss ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)"
    fit = smf.ols(formula, data=d1).fit()
    observed_stat = _add_contrasts(fit.params)
    print(observed_stat[contrast_names])

    # Calculate permutations
    np.random.seed(seed)
    perm_stat = {}
    for i in tqdm(range(n_perm)):
        # Permute sample-pair labels within subjects.
        # NOTE(review): labels are shuffled over *all* pairs in d0 (grouped by
        # sampleA's subject), so between-subject pairs' labels can land on the
        # within-subject rows that are modelled -- confirm this is intended
        # rather than shuffling within d1 only.
        _perm_type = d0.groupby("subject_id").pair_type.transform(
            np.random.permutation
        )
        perm_fit = smf.ols(
            formula,
            data=d0.assign(pair_type=_perm_type)[lambda x: x.same_subject],
        ).fit()
        perm_stat[i] = _add_contrasts(perm_fit.params)
    perm_stat = pd.DataFrame(perm_stat).T

    perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

    # Plot permutation tests: one null histogram per contrast, with the
    # observed value as a vertical line and the p-value in the corner.
    fig, axs = plt.subplots(3)
    for ax, param_name in zip(axs, contrast_names):
        ax.set_title(param_name)
        ax.hist(perm_stat[param_name], bins=20)
        ax.axvline(observed_stat[param_name])
        ax.annotate(
            perm_pvalues[param_name],
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
        )
    fig.tight_layout()

    # Residual diagnostics, colored by Cook's distance.
    d2 = d1.assign(
        resid_pearson=fit.resid_pearson,
        influence=fit.get_influence().summary_frame().cooks_d,
    ).sort_values("time_delta")

    fig, ax = plt.subplots()
    art = ax.scatter(
        "time_delta",
        "resid_pearson",
        c="influence",
        data=d2,
    )
    fig.colorbar(art, label="Cook's D")
    ax.set_xlabel("Within-Subjects Days between Samples")
    ax.set_xscale("symlog")

    # Predictions are made for one subject and then shifted by that subject's
    # coefficient, so the curves do not include that subject's effect.
    # NOTE(review): this needs at least two subjects, and the coefficient
    # lookup raises KeyError if the chosen subject has no explicit
    # `C(subject_id, Sum)[S.<id>]` term in the fit -- confirm.
    _arbitrary_subject = d1.subject_id.unique()[1]
    _subject_effect = fit.params[f"C(subject_id, Sum)[S.{_arbitrary_subject}]"]

    def _prediction_grid(log10_start):
        grid = pd.DataFrame(
            product(
                [_arbitrary_subject],
                ["EEN", "PostEEN", "Transition"],
                np.logspace(log10_start, 2.6),
            ),
            columns=["subject_id", "pair_type", "time_delta"],
        )
        return grid.assign(
            prediction=fit.predict(grid),
            predict_mean_subject=lambda x: x.prediction - _subject_effect,
        )

    def _plot_trends(ax, grid, color_of):
        # Draw each pair type's curve only between the 5% and 95% quantiles
        # of that pair type's observed time deltas.
        for pair_type in pair_type_order:
            d4 = grid[lambda x: x.pair_type == pair_type]
            left, right = d1[lambda x: x.pair_type == pair_type].time_delta.quantile(
                [0.05, 0.95]
            )
            ax.plot(
                "time_delta",
                "predict_mean_subject",
                label="__nolegend__",
                data=d4[lambda x: (x.time_delta > left) & (x.time_delta < right)],
                color=color_of(pair_type),
                linestyle=pair_type_linestyle_palette[pair_type],
            )

    # Turnover vs. time, colored by pair type.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for pair_type in pair_type_order:
        d3 = d1[lambda x: (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "diss",
            label="__nolegend__",
            color=pair_type_palette[pair_type],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
        )
    _plot_trends(ax, _prediction_grid(1.0), lambda pt: pair_type_palette[pt])
    ax.set_ylabel("Unifrac Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)
    # Empty artists that exist only to populate the legend.
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color=pair_type_palette[pair_type],
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()

    # Turnover vs. time, colored by subject (marker encodes pair type).
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for subject_id, pair_type in product(subject_order, pair_type_order):
        d3 = d2[lambda x: (x.subject_id == subject_id) & (x.pair_type == pair_type)]
        ax.scatter(
            "time_delta",
            "diss",
            label="__nolegend__",
            color=subject_palette[subject_id],
            data=d3,
            marker=pair_type_marker_palette[pair_type],
            alpha=0.7,
        )
    _plot_trends(ax, _prediction_grid(1.1), lambda pt: "black")
    ax.set_ylabel("Unifrac Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)

    # Stand-alone legend figures.
    fig, ax = plt.subplots(figsize=(5, 2))
    for subject_id in subject_order:
        ax.scatter(
            [], [], label=subject_id, color=subject_palette[subject_id], marker="s"
        )
    ax.legend(ncols=5, title="Subject")

    fig, ax = plt.subplots(figsize=(2.5, 1.5))
    for pair_type in pair_type_order:
        ax.plot(
            [],
            [],
            label=pair_type,
            color="black",
            marker=pair_type_marker_palette[pair_type],
            linestyle=pair_type_linestyle_palette[pair_type],
        )
    ax.legend()
    return fit.summary()

#### E. coli

In [None]:
# E. coli: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "102506"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### E. lenta

In [None]:
# E. lenta: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "102544"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### C. scindens

In [None]:
# C. scindens: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "101303"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### H. hathewayi

In [None]:
# H. hathewayi: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "100150"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### F. plautii

In [None]:
# F. plautii: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "100099"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### R. gnavus

In [None]:
# R. gnavus: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "101380"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### B. uniformis

In [None]:
# B. uniformis: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "101346"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

#### B. dorei

In [None]:
# B. dorei: UniFrac-based turnover on the fitted world for this species,
# restricted to EEN and PostEEN samples.
_species_id = "102478"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)

turnover_analysis6(
    _world,
    meta.loc[meta["sample_type"].isin(["EEN", "PostEEN"])],
)

### Using Metagenotype instead of BC Diss

In [None]:
def turnover_analysis4(
    _world,
    _meta,
    pair_type_palette=pair_type_palette,
    pair_type_marker_palette=pair_type_marker_palette,
    pair_type_linestyle_palette=pair_type_linestyle_palette,
    subject_palette=subject_palette,
    subject_order=subject_order,
    pair_type_order=pair_type_order,
):
    """Model within-subject metagenotype dissimilarity by sample-pair type.

    For every pair of sufficiently covered samples the metagenotype
    dissimilarity, the absolute difference in
    ``collection_date_relative_een_end`` and the pair type ("EEN",
    "Transition" or "PostEEN", from how many of the two samples are
    "PostEEN") are tabulated.  An OLS model is fit to the same-subject pairs,
    the differences between pair-type coefficients are compared against 999
    label permutations, and several diagnostic / summary figures are drawn.

    Side effects: creates matplotlib figures and reseeds NumPy's global RNG
    with ``np.random.seed(1)``.

    Parameters
    ----------
    _world
        Object exposing ``metagenotype.mean_depth()``, ``metagenotype.pdist()``
        and ``sel(sample=...)`` (callers pass a loaded ``sf.data.World``).
    _meta
        Sample metadata indexed by sample, with columns ``subject_id``,
        ``sample_type`` and ``collection_date_relative_een_end``.
    pair_type_palette, pair_type_marker_palette, pair_type_linestyle_palette
        Mappings from pair type to color / marker / linestyle.
    subject_palette
        Mapping from subject id to color.
    subject_order, pair_type_order
        Iteration (and plotting) order of subjects and pair types.

    Returns
    -------
    The ``summary()`` of the OLS fit on the observed same-subject pairs.
    """
    pair_type_term = "C(pair_type)[{}]"
    contrasts = {
        "PostEEN - EEN": ("PostEEN", "EEN"),
        "Transition - EEN": ("Transition", "EEN"),
        "Transition - PostEEN": ("Transition", "PostEEN"),
    }

    def _with_contrasts(params):
        """Return a copy of ``params`` extended with the pair-type contrasts."""
        params = params.copy()
        for name, (minuend, subtrahend) in contrasts.items():
            params[name] = (
                params[pair_type_term.format(minuend)]
                - params[pair_type_term.format(subtrahend)]
            )
        return params

    # Select data: keep samples present in the metadata whose mean depth is
    # above 0.5.  The metadata order is preserved so that the pair order (and
    # hence the seeded permutations below) is reproducible; iterating over a
    # set intersection gives an arbitrary order.
    _deep_samples = set(
        idxwhere(_world.metagenotype.mean_depth().to_series() > 0.5)
    )
    _sample_list = [s for s in _meta.index if s in _deep_samples]
    # Subset the metadata that was passed in, not the notebook-global `meta`.
    _meta = _meta.loc[_sample_list]
    _world = _world.sel(sample=_sample_list)

    # Calculate pairwise comparisons (all in condensed, upper-triangle order).
    # NOTE(review): assumes metagenotype.pdist() returns a square matrix in
    # the sample order of `_world` -- confirm against sfacts.
    mgtp_cdist = squareform(_world.metagenotype.pdist())
    time_cdist = pdist(
        _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
    )
    diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
    same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
    # Number of "PostEEN" samples in the pair: 0, 1 or 2.
    type_transition_indicator = pdist(
        _meta[["sample_type"]],
        metric=lambda x, y: (x == "PostEEN").astype(float)
        + (y == "PostEEN").astype(float),
    )

    # Sample labels for each condensed entry (i < j): sampleB is the earlier
    # row i and sampleA the later row j, in the same row-major order as pdist.
    first_idx, second_idx = np.triu_indices(len(_meta), k=1)
    sample_ids = _meta.index.to_numpy()
    d0 = pd.DataFrame(
        dict(
            sampleA=sample_ids[second_idx],
            sampleB=sample_ids[first_idx],
            mdiss=mgtp_cdist,
            same_subject=same_subject_cdist,
            type_transition_indicator=type_transition_indicator,
            diff_type=type_transition_indicator == 1,
            time_delta=time_cdist,
        )
    ).assign(
        pair_type=lambda x: x.type_transition_indicator.map(
            {0: "EEN", 1: "Transition", 2: "PostEEN"}
        ),
        subject_id=lambda x: x.sampleA.map(_meta.subject_id),
        subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    )
    d1 = d0[lambda x: x.same_subject]

    # Plot the observed relationship
    fig, axs = plt.subplots(1, 2)
    sns.boxenplot(
        y="mdiss",
        x="same_subject",
        hue="pair_type",
        data=d0,
        ax=axs[0],
        palette=pair_type_palette,
        order=[False, True],
        hue_order=pair_type_order,
    )
    sns.stripplot(
        y="mdiss",
        x="pair_type",
        hue="subject_id",
        data=d1,
        ax=axs[1],
        palette=subject_palette,
        order=pair_type_order,
    )
    axs[1].legend_.set_visible(False)

    # Fit observed relationship
    formula = "mdiss ~ 0 + C(pair_type) + cr(time_delta, 4) + C(subject_id, Sum)"
    fit = smf.ols(formula, data=d1).fit()
    observed_stat = _with_contrasts(fit.params)

    # Calculate permutations
    np.random.seed(1)
    perm_stat = {}
    for i in tqdm(range(999)):
        # Permute sample-pair labels within subjects.
        # NOTE(review): the shuffle runs over all pairs grouped by sampleA's
        # subject (between-subject pairs included) and only afterwards is the
        # table restricted to same-subject pairs -- confirm this is the
        # intended null distribution.
        _perm_type = d0.groupby("subject_id").pair_type.transform(
            np.random.permutation
        )
        perm_fit = smf.ols(
            formula,
            data=d0.assign(pair_type=_perm_type)[lambda x: x.same_subject],
        ).fit()
        perm_stat[i] = _with_contrasts(perm_fit.params)
    perm_stat = pd.DataFrame(perm_stat).T

    perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

    # Plot permutation tests: null histogram, observed value and p-value.
    fig, axs = plt.subplots(3)
    for ax, param_name in zip(
        axs, ["Transition - EEN", "Transition - PostEEN", "PostEEN - EEN"]
    ):
        ax.set_title(param_name)
        ax.hist(perm_stat[param_name], bins=20)
        ax.axvline(observed_stat[param_name])
        ax.annotate(
            perm_pvalues[param_name],
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
        )
    fig.tight_layout()

    # Residual diagnostics, colored by Cook's distance.
    d2 = d1.assign(
        resid_pearson=fit.resid_pearson,
        influence=fit.get_influence().summary_frame().cooks_d,
    ).sort_values("time_delta")

    fig, ax = plt.subplots()
    art = ax.scatter(
        "time_delta",
        "resid_pearson",
        c="influence",
        data=d2,
    )
    fig.colorbar(art, label="Cook's D")
    ax.set_xlabel("Within-Subjects Days between Samples")
    ax.set_xscale("symlog")

    # Predictions are made for one subject and that subject's coefficient is
    # subtracted again.
    # NOTE(review): requires at least two subjects, and the chosen subject
    # must have its own `C(subject_id, Sum)[S.<id>]` coefficient -- confirm.
    _arbitrary_subject = d1.subject_id.unique()[1]
    subject_effect = fit.params[f"C(subject_id, Sum)[S.{_arbitrary_subject}]"]

    def _mean_subject_curves(log10_start):
        """Predict over a log-spaced time grid for each pair type."""
        grid = pd.DataFrame(
            product(
                [_arbitrary_subject],
                ["EEN", "PostEEN", "Transition"],
                np.logspace(log10_start, 2.6),
            ),
            columns=["subject_id", "pair_type", "time_delta"],
        )
        return grid.assign(
            prediction=fit.predict(grid),
            predict_mean_subject=lambda x: x.prediction - subject_effect,
        )

    def _plot_curves(ax, predict_data, color=None):
        """Draw each pair type's curve between its 5% and 95% time quantiles."""
        for pair_type in pair_type_order:
            curve = predict_data[lambda x: x.pair_type == pair_type]
            left, right = d1[lambda x: x.pair_type == pair_type].time_delta.quantile(
                [0.05, 0.95]
            )
            ax.plot(
                "time_delta",
                "predict_mean_subject",
                label="__nolegend__",
                data=curve[lambda x: (x.time_delta > left) & (x.time_delta < right)],
                color=pair_type_palette[pair_type] if color is None else color,
                linestyle=pair_type_linestyle_palette[pair_type],
            )

    def _plot_pair_type_legend_handles(ax, color=None):
        """Add empty line artists that only serve as legend entries."""
        for pair_type in pair_type_order:
            ax.plot(
                [],
                [],
                label=pair_type,
                color=pair_type_palette[pair_type] if color is None else color,
                marker=pair_type_marker_palette[pair_type],
                linestyle=pair_type_linestyle_palette[pair_type],
            )

    # Same-subject pairs colored by pair type, with fitted curves.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover (Species)")
    for pair_type in pair_type_order:
        ax.scatter(
            "time_delta",
            "mdiss",
            label="__nolegend__",
            color=pair_type_palette[pair_type],
            data=d1[lambda x: (x.pair_type == pair_type)],
            marker=pair_type_marker_palette[pair_type],
        )
    _plot_curves(ax, _mean_subject_curves(1.0))
    ax.set_ylabel("Metagenotype Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)
    _plot_pair_type_legend_handles(ax)
    ax.legend()

    # Same-subject pairs colored by subject, with fitted curves in black.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for subject_id, pair_type in product(subject_order, pair_type_order):
        ax.scatter(
            "time_delta",
            "mdiss",
            label="__nolegend__",
            color=subject_palette[subject_id],
            data=d2[
                lambda x: (x.subject_id == subject_id) & (x.pair_type == pair_type)
            ],
            marker=pair_type_marker_palette[pair_type],
            alpha=0.7,
        )
    _plot_curves(ax, _mean_subject_curves(1.1), color="black")
    ax.set_ylabel("Metagenotype Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)

    # Stand-alone legends
    fig, ax = plt.subplots(figsize=(5, 2))
    for subject_id in subject_order:
        ax.scatter(
            [], [], label=subject_id, color=subject_palette[subject_id], marker="s"
        )
    ax.legend(ncols=5, title="Subject")

    fig, ax = plt.subplots(figsize=(2.5, 1.5))
    _plot_pair_type_legend_handles(ax, color="black")
    ax.legend()
    return fit.summary()

#### E. coli

In [None]:
# E. coli: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "102506"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

#### E. lenta

In [None]:
# E. lenta: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "102544"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

#### C. scindens

In [None]:
# C. scindens: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "101303"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

#### H. hathewayi

In [None]:
# H. hathewayi: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "100150"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

#### F. plautii

In [None]:
# F. plautii: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "100099"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

#### R. gnavus

In [None]:
# R. gnavus: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "101380"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

#### B. uniformis

In [None]:
# B. uniformis: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "101346"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

#### B. dorei

In [None]:
# B. dorei: load the sfacts world stored for this species and run
# turnover_analysis4 on the EEN / PostEEN samples only.
_species_id = "102478"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
turnover_analysis4(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _world=_world,
)

## Permutation Test Prototype #5

In [None]:
def turnover_analysis5(
    _rabund,
    _meta,
    pair_type_palette=pair_type_palette,
    pair_type_marker_palette=pair_type_marker_palette,
    pair_type_linestyle_palette=pair_type_linestyle_palette,
    subject_palette=subject_palette,
    subject_order=subject_order,
    pair_type_order=pair_type_order,
):
    """Model within-subject Bray-Curtis dissimilarity by sample-pair type.

    For every pair of samples the Bray-Curtis dissimilarity of their rows in
    ``_rabund``, the absolute difference in
    ``collection_date_relative_een_end`` and the pair type ("EEN",
    "Transition" or "PostEEN", from how many of the two samples are
    "PostEEN") are tabulated.  An OLS model with a pair-type-specific
    ``np.log(time_delta)`` slope is fit to the same-subject pairs, the
    differences between those slopes are compared against 999 label
    permutations, and several diagnostic / summary figures are drawn.

    Side effects: prints the observed contrasts, creates matplotlib figures
    and reseeds NumPy's global RNG with ``np.random.seed(1)``.

    Parameters
    ----------
    _rabund
        Table with one row per sample, passed to ``pdist(..., "braycurtis")``.
    _meta
        Sample metadata indexed by sample, with columns ``subject_id``,
        ``sample_type`` and ``collection_date_relative_een_end``.
    pair_type_palette, pair_type_marker_palette, pair_type_linestyle_palette
        Mappings from pair type to color / marker / linestyle.
    subject_palette
        Mapping from subject id to color.
    subject_order, pair_type_order
        Iteration (and plotting) order of subjects and pair types.

    Returns
    -------
    The ``summary()`` of the OLS fit on the observed same-subject pairs.
    """
    slope_term = "C(pair_type)[{}]:np.log(time_delta)"
    contrasts = {
        "PostEEN - EEN": ("PostEEN", "EEN"),
        "Transition - EEN": ("Transition", "EEN"),
        "Transition - PostEEN": ("Transition", "PostEEN"),
    }

    def _with_contrasts(params):
        """Return a copy of ``params`` extended with the slope contrasts."""
        params = params.copy()
        for name, (minuend, subtrahend) in contrasts.items():
            params[name] = (
                params[slope_term.format(minuend)]
                - params[slope_term.format(subtrahend)]
            )
        return params

    # Select data
    # NOTE(review): presumably restricts both tables to their shared samples
    # in a common order -- confirm against lib.pandas_util.align_indexes.
    _rabund, _meta = lib.pandas_util.align_indexes(_rabund, _meta)

    # Calculate pairwise comparisons (all in condensed, upper-triangle order).
    bc_cdist = pdist(_rabund, metric="braycurtis")
    time_cdist = pdist(
        _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
    )
    diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
    same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
    # Number of "PostEEN" samples in the pair: 0, 1 or 2.
    type_transition_indicator = pdist(
        _meta[["sample_type"]],
        metric=lambda x, y: (x == "PostEEN").astype(float)
        + (y == "PostEEN").astype(float),
    )

    # Sample labels for each condensed entry (i < j): sampleB is the earlier
    # row i and sampleA the later row j, in the same row-major order as pdist.
    first_idx, second_idx = np.triu_indices(len(_meta), k=1)
    sample_ids = _meta.index.to_numpy()
    d0 = pd.DataFrame(
        dict(
            sampleA=sample_ids[second_idx],
            sampleB=sample_ids[first_idx],
            bc=bc_cdist,
            same_subject=same_subject_cdist,
            type_transition_indicator=type_transition_indicator,
            diff_type=type_transition_indicator == 1,
            time_delta=time_cdist,
        )
    ).assign(
        pair_type=lambda x: x.type_transition_indicator.map(
            {0: "EEN", 1: "Transition", 2: "PostEEN"}
        ),
        subject_id=lambda x: x.sampleA.map(_meta.subject_id),
        subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    )
    d1 = d0[lambda x: x.same_subject]

    # Plot the observed relationship
    fig, axs = plt.subplots(1, 2)
    sns.boxenplot(
        y="bc",
        x="same_subject",
        hue="pair_type",
        data=d0,
        ax=axs[0],
        palette=pair_type_palette,
        order=[False, True],
        hue_order=pair_type_order,
    )
    sns.stripplot(
        y="bc",
        x="pair_type",
        hue="subject_id",
        data=d1,
        ax=axs[1],
        palette=subject_palette,
        order=pair_type_order,
    )
    axs[1].legend_.set_visible(False)

    # Fit observed relationship
    # NOTE(review): np.log(time_delta) is -inf for same-subject pairs with a
    # zero time difference -- confirm such pairs cannot occur.
    formula = "bc ~ 0 + cr(time_delta, 4) + C(subject_id, Sum) + C(pair_type):np.log(time_delta)"
    fit = smf.ols(formula, data=d1).fit()
    observed_stat = _with_contrasts(fit.params)
    print(observed_stat[["Transition - EEN", "Transition - PostEEN", "PostEEN - EEN"]])

    # Calculate permutations
    np.random.seed(1)
    perm_stat = {}
    for i in tqdm(range(999)):
        # Permute sample-pair labels within subjects.
        # NOTE(review): the shuffle runs over all pairs grouped by sampleA's
        # subject (between-subject pairs included) and only afterwards is the
        # table restricted to same-subject pairs -- confirm this is the
        # intended null distribution.
        _perm_type = d0.groupby("subject_id").pair_type.transform(
            np.random.permutation
        )
        perm_fit = smf.ols(
            formula,
            data=d0.assign(pair_type=_perm_type)[lambda x: x.same_subject],
        ).fit()
        perm_stat[i] = _with_contrasts(perm_fit.params)
    perm_stat = pd.DataFrame(perm_stat).T

    perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

    # Plot permutation tests: null histogram, observed value and p-value.
    fig, axs = plt.subplots(3)
    for ax, param_name in zip(
        axs, ["Transition - EEN", "Transition - PostEEN", "PostEEN - EEN"]
    ):
        ax.set_title(param_name)
        ax.hist(perm_stat[param_name], bins=20)
        ax.axvline(observed_stat[param_name])
        ax.annotate(
            perm_pvalues[param_name],
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
        )
    fig.tight_layout()

    # Residual diagnostics, colored by Cook's distance.
    d2 = d1.assign(
        resid_pearson=fit.resid_pearson,
        influence=fit.get_influence().summary_frame().cooks_d,
    ).sort_values("time_delta")

    fig, ax = plt.subplots()
    art = ax.scatter(
        "time_delta",
        "resid_pearson",
        c="influence",
        data=d2,
    )
    fig.colorbar(art, label="Cook's D")
    ax.set_xlabel("Within-Subjects Days between Samples")
    ax.set_xscale("symlog")

    # Predictions are made for one subject and that subject's coefficient is
    # subtracted again.
    # NOTE(review): requires at least two subjects, and the chosen subject
    # must have its own `C(subject_id, Sum)[S.<id>]` coefficient -- confirm.
    _arbitrary_subject = d1.subject_id.unique()[1]
    subject_effect = fit.params[f"C(subject_id, Sum)[S.{_arbitrary_subject}]"]

    def _mean_subject_curves(log10_start):
        """Predict over a log-spaced time grid for each pair type."""
        grid = pd.DataFrame(
            product(
                [_arbitrary_subject],
                ["EEN", "PostEEN", "Transition"],
                np.logspace(log10_start, 2.6),
            ),
            columns=["subject_id", "pair_type", "time_delta"],
        )
        return grid.assign(
            prediction=fit.predict(grid),
            predict_mean_subject=lambda x: x.prediction - subject_effect,
        )

    def _plot_curves(ax, predict_data, color=None):
        """Draw each pair type's curve between its 5% and 95% time quantiles."""
        for pair_type in pair_type_order:
            curve = predict_data[lambda x: x.pair_type == pair_type]
            left, right = d1[lambda x: x.pair_type == pair_type].time_delta.quantile(
                [0.05, 0.95]
            )
            ax.plot(
                "time_delta",
                "predict_mean_subject",
                label="__nolegend__",
                data=curve[lambda x: (x.time_delta > left) & (x.time_delta < right)],
                color=pair_type_palette[pair_type] if color is None else color,
                linestyle=pair_type_linestyle_palette[pair_type],
            )

    def _plot_pair_type_legend_handles(ax, color=None):
        """Add empty line artists that only serve as legend entries."""
        for pair_type in pair_type_order:
            ax.plot(
                [],
                [],
                label=pair_type,
                color=pair_type_palette[pair_type] if color is None else color,
                marker=pair_type_marker_palette[pair_type],
                linestyle=pair_type_linestyle_palette[pair_type],
            )

    # Same-subject pairs colored by pair type, with fitted curves.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for pair_type in pair_type_order:
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=pair_type_palette[pair_type],
            data=d1[lambda x: (x.pair_type == pair_type)],
            marker=pair_type_marker_palette[pair_type],
        )
    _plot_curves(ax, _mean_subject_curves(1.0))
    ax.set_ylabel("Bray-Curtis Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)
    _plot_pair_type_legend_handles(ax)
    ax.legend()

    # Same-subject pairs colored by subject, with fitted curves in black.
    fig, ax = plt.subplots()
    ax.set_title("Within-subject Pairwise Turnover")
    for subject_id, pair_type in product(subject_order, pair_type_order):
        ax.scatter(
            "time_delta",
            "bc",
            label="__nolegend__",
            color=subject_palette[subject_id],
            data=d2[
                lambda x: (x.subject_id == subject_id) & (x.pair_type == pair_type)
            ],
            marker=pair_type_marker_palette[pair_type],
            alpha=0.7,
        )
    _plot_curves(ax, _mean_subject_curves(1.1), color="black")
    ax.set_ylabel("Bray-Curtis Dissimilarity")
    ax.set_xlabel("Days between Samples")
    ax.set_xscale("symlog", linthresh=1e-1)

    # Stand-alone legends
    fig, ax = plt.subplots(figsize=(5, 2))
    for subject_id in subject_order:
        ax.scatter(
            [], [], label=subject_id, color=subject_palette[subject_id], marker="s"
        )
    ax.legend(ncols=5, title="Subject")

    fig, ax = plt.subplots(figsize=(2.5, 1.5))
    _plot_pair_type_legend_handles(ax, color="black")
    ax.legend()
    return fit.summary()

### Species Turnover

In [None]:
# Species-level turnover on the EEN / PostEEN samples only.
turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=species_rabund,
)

### Strain Turnover

In [None]:
# Strain-level turnover on the EEN / PostEEN samples only.
turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=strain_rabund,
)

### Within Species Turnover Analysis

#### E. coli

In [None]:
# E. coli: take the community table of samples with mean metagenotype depth
# above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "102506"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)

#### E. lenta

In [None]:
# E. lenta: take the community table of samples with mean metagenotype depth
# above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "102544"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)

#### C. scindens

In [None]:
# C. scindens: take the community table of samples with mean metagenotype
# depth above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "101303"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)

#### H. hathewayi

In [None]:
# H. hathewayi: take the community table of samples with mean metagenotype
# depth above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "100150"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)

#### F. plautii

In [None]:
# F. plautii: take the community table of samples with mean metagenotype depth
# above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "100099"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)

#### R. gnavus

In [None]:
# R. gnavus: take the community table of samples with mean metagenotype depth
# above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "101380"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)

#### B. uniformis

In [None]:
# B. uniformis: take the community table of samples with mean metagenotype
# depth above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "101346"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)

#### B. dorei

In [None]:
# B. dorei: take the community table of samples with mean metagenotype depth
# above 0.5 and run turnover_analysis5 on the EEN / PostEEN samples only.
_species_id = "102478"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 0.5)
    .community.to_pandas()
)

turnover_analysis5(
    _meta=meta.loc[meta.sample_type.isin(["EEN", "PostEEN"])],
    _rabund=single_species_strain_rabund,
)