## Preamble

In [None]:
# Load IPython's autoreload extension so edited project modules can be re-imported.
# NOTE(review): no `%autoreload` mode is set in the visible cells - confirm that
# reloading is actually switched on somewhere.
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so the relative data/, meta/, ref/ and fig/ paths
# used below resolve. Raises KeyError if PROJECT_ROOT is not set.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Cell output: the resolved working directory, as a sanity check.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def _calculate_2tailed_pvalue_from_perm(obs, perms):
    hypoth_left = perms > obs
    hypoth_right = perms < obs
    null_p_left = (hypoth_left.sum() + 1) / (len(hypoth_left) + 1)
    null_p_right = (hypoth_right.sum() + 1) / (len(hypoth_right) + 1)
    return np.minimum(null_p_left, null_p_right) * 2

In [None]:
def linkage_order(linkage, labels):
    """Return `labels` reordered to follow the leaves of a hierarchical clustering.

    `linkage` is a SciPy linkage matrix. `labels` must support indexing with a
    list of integer positions (e.g. a NumPy array); the result has the same
    type as that indexing operation returns.
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    # Pre-order traversal visits the leaves left to right, collecting their ids.
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]



def is_prime(n):
    """Return True if `n` is prime, using trial division up to sqrt(n)."""
    if n <= 1:
        return False
    limit = int(n**0.5) + 1
    return all(n % divisor != 0 for divisor in range(2, limit))


def iterate_primes_up_to(n, return_index=False):
    """Yield the primes strictly below ``ceil(n)`` in increasing order.

    With ``return_index=True`` each item is ``(k, prime)``, where ``k`` counts
    the primes yielded so far, starting at 0.
    """
    upper = int(np.ceil(n))
    primes = (candidate for candidate in range(upper) if is_prime(candidate))
    for position, prime in enumerate(primes):
        yield (position, prime) if return_index else prime


def maximally_shuffled_order(sorted_order):
    """Return the items of `sorted_order` rearranged by their position residues.

    Each item's original position is reduced modulo every prime produced by
    ``iterate_primes_up_to(sqrt(n))``. Sorting on those residues gives a
    permutation of positions, which is then assigned back row by row as each
    item's new rank. Judging by the name, the intent is to pull apart items
    that were adjacent in the input.

    Parameters
    ----------
    sorted_order : sequence
        Labels in their original order; used as a DataFrame index.

    Returns
    -------
    list
        The same labels in the shuffled order. Inputs too short to yield any
        prime (including empty input) come back in their original order.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    if not primes_list:
        # Nothing to sort on; avoid relying on how pandas treats an empty `by`.
        return table.index.to_list()
    for prime in primes_list:
        table[prime] = table.original_order % prime
    # Positions in residue-sorted order, assigned positionally as the new rank.
    shuffled_positions = table.sort_values(primes_list).original_order.values
    table = table.assign(new_order=shuffled_positions)
    return table.sort_values("new_order").index.to_list()

In [None]:
# Species depth table: one row per sample, one column per species id (as str),
# with 0 where a sample/species pair is absent from the file.
species_depth = (
    pd.read_table(
        "data/group/een/r.proc.gtpro.species_depth.tsv",
        index_col=["sample", "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str)
)

# Log-spaced bins to match the log x-axis. The previous linear `bins` was never
# passed to `hist`, which built these same log bins inline.
bins = np.logspace(-1, 5, num=100)

fig, axs = plt.subplots(2, sharex=True)

# Top panel: total depth per sample; bottom panel: total depth per species.
for (title, x), ax in zip(
    dict(
        total_depth_by_sample=species_depth.sum(1),
        total_depth_by_species=species_depth.sum(0),
    ).items(),
    axs.flatten(),
):
    ax.hist(x, bins=bins)
    ax.set_title(title)
    ax.set_xscale("log")
fig.tight_layout()

In [None]:
# Per-sample relative abundance: each row of depths divided by its row total.
species_rabund = species_depth.divide(species_depth.sum(1), axis=0)
# Cell output: the 20 species with the highest mean relative abundance.
species_rabund.mean().sort_values(ascending=False).head(20)

In [None]:
# NOTE(review): same computation as the previous cell; kept so this cell can be
# run on its own.
species_rabund = species_depth.divide(species_depth.sum(1), axis=0)

In [None]:
n_species = 10
# Relative-abundance cutoff above which a species counts as present; shared by
# the ranking, the prevalence annotation and the dotted marker line.
_presence_cutoff = 1e-5
_n_samples_present = (species_rabund > _presence_cutoff).sum()
top_species = _n_samples_present.sort_values(ascending=False).head(n_species).index

fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

bins = np.logspace(-8, 1, num=51)

# One stripped-down histogram per species, stacked vertically.
for species_id, ax in zip(top_species, axs):
    _rabund = species_rabund[species_id]
    ax.hist(_rabund, bins=bins, alpha=0.7)
    ax.set_xscale("log")
    prevalence = (_rabund > _presence_cutoff).mean()
    ax.set_title("")
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    ax.axvline(_presence_cutoff, lw=1, linestyle=":", color="k")

# `ax` is still the last (bottom) panel: give only that one a visible x-axis.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Taxonomy table for the reference species, loaded with the project helper.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    "ref/gtpro/species_taxonomy_ext.tsv"
)
# Manual override of the species-level name for species 102506.
species_taxonomy.loc["102506", "s__"] = "s__Escherichia coli"

In [None]:
# Print the full taxonomy string of each species currently in `top_species`.
for _species_id in top_species.astype(str):
    print(_species_id, ":", species_taxonomy.taxonomy_string.loc[_species_id])

In [None]:
# Look up species whose species-level name ends with "hansenii".
species_taxonomy[lambda x: x.s__.str.endswith("hansenii")]

In [None]:
# For each focal species print: id, fraction of samples with relative abundance
# above 0.01%, fraction above 0.1%, and the species-level name.
for _species_id in ["102544", "102506", "101303", "100150", "102330", "101704"]:
    print(
        _species_id,
        (species_rabund[_species_id] > 0.0001).mean().round(2),
        (species_rabund[_species_id] > 0.001).mean().round(2),
        species_taxonomy.loc[_species_id].s__,
        sep="\t\t",
    )

In [None]:
n_species = 20
# Relative-abundance cutoff above which a species counts as present. It drives
# the ranking, the prevalence annotation AND the dotted marker line; the marker
# was previously left at 1e-5 from the earlier cell and did not match.
_presence_cutoff = 1e-3
top_species = (
    (species_rabund > _presence_cutoff)
    .sum()
    .sort_values(ascending=False)
    .head(n_species)
    .index
)

fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

bins = np.logspace(-8, 1, num=51)

# One stripped-down histogram per species, stacked vertically.
for species_id, ax in zip(top_species, axs):
    ax.hist(species_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    prevalence = (species_rabund[species_id] > _presence_cutoff).mean()
    ax.set_title("")
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    ax.axvline(_presence_cutoff, lw=1, linestyle=":", color="k")

# `ax` is still the last (bottom) panel: give only that one a visible x-axis.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Split each species' depth among its inferred strains, producing a table with
# one row per sample and one "<species_id>_<strain>" column per kept strain.
strain_depth = []
missing_files = []
for species_id in species_depth.columns:
    path = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        d = pd.read_table(path, index_col=["sample", "strain"]).squeeze().unstack()
    except FileNotFoundError:
        # No strain fit for this species: start from an empty table, so all of
        # its depth ends up in the catch-all column created below.
        missing_files.append(path)
        d = pd.DataFrame([])
    # Keep strains whose fraction summed over samples exceeds 0.05
    # (`idxwhere` presumably returns the labels where the mask is True).
    _keep_strains = idxwhere(d.sum() > 0.05)
    d = d.reindex(index=species_depth.index, columns=_keep_strains, fill_value=0)
    # Whatever the kept strains do not account for goes to a column named -1.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    # Clip negative entries to 0, then renormalise every row to sum to 1.
    d[d < 0] = 0
    d = d.divide(d.sum(1), axis=0)
    # Scale the fractions by the species' depth in each sample.
    d = d.multiply(species_depth[species_id], axis=0)
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    strain_depth.append(d)
# From here on `strain_depth` is the combined DataFrame, not the list.
strain_depth = pd.concat(strain_depth, axis=1)
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)
# Cell output: (number of species, number of species with no strain-fit file).
len(species_depth.columns), len(missing_files)

In [None]:
# Distribution of all strain relative abundances on log-log axes. The 1e-10
# offset moves exact zeros onto the lowest bin edge so they are still counted.
plt.hist(strain_rabund.values.flatten() + 1e-10, bins=np.logspace(-10, 0))
plt.yscale("log")
plt.xscale("log")

In [None]:
# Print the full taxonomy string of each species currently in `top_species`
# (by now the top 20 ranked with the 1e-3 cutoff, if cells ran in order).
for _species_id in top_species.astype(str):
    print(_species_id, ":", species_taxonomy.taxonomy_string.loc[_species_id])

In [None]:
# Stool-sample metadata, indexed by metagenome id (`mgen_id`).
meta = (
    pd.read_table("meta/een-mgen/stool.tsv")
    # .rename(columns={'Seq-Name': 'sample', 'CED/Patient-recoded': 'subject_id', 'sampleDate': 'date', 'Diet (=PreEEN, EEN, PostEEN)': 'sample_type'})
    # .assign(
    #     date=lambda x: pd.to_datetime(x.date),
    #     sample_type=lambda x: x.sample_type.fillna('???')
    # )
    .set_index("mgen_id")
    # .sort_values(['subject_id', 'date', 'sample_type'])
)

# FIXME: Metadata seems to include a swap in the metagenomic data of CF_11 and CF_15.
# Workaround: exchange the two index labels, so each metadata row is looked up
# under the other sample's id.
meta = meta.rename({"CF_11": "CF_15", "CF_15": "CF_11"})
meta

In [None]:
# Microcosm-sample metadata, indexed by `mgen_id`. The inoculum's subject id is
# renamed to `subject_id` so it lines up with the stool metadata column.
micro_meta = (
    pd.read_table("meta/een-mgen/microcosm.tsv")
    .set_index("mgen_id")
    .rename(columns={"inoculum_subject_id": "subject_id"})
)
micro_meta

In [None]:
# Every stool metadata sample must have a row in the species abundance table.
assert meta.index.isin(species_rabund.index).all()

In [None]:
# Number of samples per subject and sample type; subjects with the most "EEN"
# samples come first.
meta[["subject_id", "sample_type"]].value_counts().unstack(fill_value=0).sort_values(
    "EEN", ascending=False
).head(20)

## Load Species Metagenotype / Strain Deconvolution

In [None]:
# Species inspected in the cells of this section; the cell output is its
# taxonomy row.
species_id = "100099"
species_taxonomy.loc[species_id]

In [None]:
# Seed NumPy's global RNG, presumably so the position subsample below is
# reproducible - TODO confirm `random_sample` draws from it.
np.random.seed(0)

# NOTE(review): `mgtp_all` is not used again in the visible part of the notebook.
mgtp_all = sf.data.Metagenotype.load(
    f"data/group/een/species/sp-{species_id}/r.proc.gtpro.mgtp.nc"
)
world = sf.data.World.load(
    f"data/group/een/species/sp-{species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
)
# Up to 1000 randomly chosen positions, to keep the plots below manageable.
position_ss = world.random_sample(position=min(1000, world.sizes["position"])).position

# Strain fractions after dropping low-abundance strains (threshold 0.05).
strain_frac = world.drop_low_abundance_strains(0.05).community.to_pandas()
abundant_strain_list = list(strain_frac.columns)
# Raises ValueError if no -1 column is present.
abundant_strain_list.remove(-1)  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(abundant_strain_list, cm="rainbow")

In [None]:
# Overview plots of the fit, restricted to the subsampled positions:
# metagenotype, community and genotype.
sf.plot.plot_metagenotype(
    world.sel(position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    dwidth=0.01,
    scalex=0.147,
    col_colors_func=None,
)
# Columns use the same metagenotype linkage as the plot above; rows are ordered
# by genotype linkage.
sf.plot.plot_community(
    world.sel(position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    row_linkage_func=lambda w: w.genotype.linkage(),
    dwidth=0.01,
    row_colors_func=None,
)
# Rows ordered by genotype linkage and annotated with genotype entropy.
sf.plot.plot_genotype(
    world.sel(position=position_ss),
    row_linkage_func=lambda w: w.genotype.linkage(),
    row_colors_func=lambda w: w.genotype.entropy(),
)

In [None]:
# Every sample in the fit must appear in the microcosm or the stool metadata.
assert (world.sample.isin(micro_meta.index) | world.sample.isin(meta.index)).all()

## Focal Species Plots

### g__Escherichia

In [None]:
# Focal species for this section; its taxonomy row is printed as a sanity check.
_species_id = "102506"
print(species_taxonomy.loc[_species_id])
# Species name without the "s__" prefix and with underscores, for titles/paths.
_species_name = species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
_species_rabund = species_rabund[_species_id]
_world = sf.data.World.load(
    # f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
    # f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts49-s85-seed0.world.nc"
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
)
# Drop low-abundance strains once and reuse the result for both the strain
# fraction table and the palette ordering (it used to be computed twice).
_world_drop_low_abund = _world.drop_low_abundance_strains(0.05)
_frac = _world_drop_low_abund.community.to_pandas()

# # Arbitrarily ordered palette:
# _strain_list = list(_frac.columns)
# _strain_list.remove(-1)  # Drop "other" strain.
# strain_palette = lib.plot.construct_ordered_palette(_strain_list, cm="rainbow")

# Genotype similarity ordered palette:
_strain_list = list(
    linkage_order(
        _world_drop_low_abund.genotype.linkage(optimal_ordering=True),
        _world_drop_low_abund.strain.values,
    )
)
# Raises ValueError if strain -1 is not present.
_strain_list.remove(-1)  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(
    _strain_list,
    cm="rainbow",
)

# # Construct an ordered palette, but use the order to AVOID closely correlated strains (e.g. found in the same subject)
# # showing up with similar colors.
# _strain_list = linkage_order(
#     sp.cluster.hierarchy.linkage(
#         _frac.groupby(meta.subject_id).mean().T,
#         method="average",
#         metric="cosine",
#         optimal_ordering=True,
#     ),
#     index=_frac.columns,
# )
# _strain_list.remove(-1)  # Drop "other" strain.
# strain_palette = lib.plot.construct_ordered_palette(
#     maximally_shuffled_order(
#         _strain_list
#     ),
#     cm="rainbow",
# )

# Per-sample plotting table: stool metadata stacked on microcosm metadata, plus
# this species' relative abundance and its strain fractions.
d0 = (
    pd.concat(
        [
            # Stool samples: the label is the tuple
            # (sample id, collection_date_relative_een_end, sample_type).
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            # Microcosm samples: the date is set to +inf so they sort after all
            # stool samples of the same subject; sample_type names the inoculum.
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(
        # has_strain_deconv=lambda x: x.index.isin(_comm.index),
        # Aligned on the index; samples missing from `_species_rabund` get NaN.
        species_rabund=_species_rabund,
    )
    # Left join on the index: samples without strain fractions get NaN columns.
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

# Subjects to plot, one panel each, in display order (note "H" is third).
subject_list = list("ABHCDEFGKLMNOPQRSTU")
# One panel per subject: stacked bars of strain fractions (left axis) with the
# species' relative abundance overlaid as a black line (right axis).
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=5,
    ax_height=4,
)
fig.suptitle(_species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    # Column -1 (the "other" strain) is stacked first, i.e. at the bottom.
    d1[[-1] + _strain_list].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Same fixed x-range in every panel; bars beyond position 14 are cut off.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

# The output file name is derived from the species name.
_outpath = f"fig/een_{_species_name}_strain_tracking.pdf"
print(_outpath)
fig.savefig(_outpath)

In [None]:
# Compare two fits of the same species (the `fit-sfacts37` and `fit-sfacts49`
# outputs). Uses `_species_id` and `d0` left over from the preceding cell.
print(_species_id)
_world0 = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
)
_world1 = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts49-s85-seed0.world.nc"
)
# Subsample 1000 positions from the first fit, keep only the samples with a
# non-missing "other" (-1) strain fraction in `d0`, then drop rare strains.
# NOTE(review): unlike the earlier `min(1000, ...)` call, this assumes the fit
# has at least 1000 positions - confirm.
w0 = (
    _world0.random_sample(position=1000)
    .sel(sample=d0.dropna(subset=[-1]).index)
    .drop_low_abundance_strains(0.05)
)
# Restrict the second fit to the same positions and samples.
w1 = _world1.sel(position=w0.position, sample=w0.sample).drop_low_abundance_strains(
    0.05
)

# Per-sample community entropy under each fit, with a y = x reference line.
plt.scatter(w0.community.entropy(), w1.community.entropy())
plt.plot([0, 1], [0, 1])

sf.plot.plot_metagenotype(w0, col_cluster=False)
sf.plot.plot_community(
    w0,
    col_cluster=False,
    row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
)
sf.plot.plot_community(
    w1,
    col_cluster=False,
    row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
)
sf.plot.plot_genotype(
    w0,
    row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
)

### s__Eggerthella lenta

In [None]:
# Focal species for this section; its taxonomy row is printed as a sanity check.
_species_id = "102544"
print(species_taxonomy.loc[_species_id])
# Species name without the "s__" prefix and with underscores, for titles/paths.
_species_name = species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
_species_rabund = species_rabund[_species_id]
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
)
_frac = _world.drop_low_abundance_strains(0.05).community.to_pandas()
_strain_list = list(_frac.columns)
# Raises ValueError if no -1 column is present.
_strain_list.remove(-1)  # Drop "other" strain.
# # Arbitrarily ordered palette:
# strain_palette = lib.plot.construct_ordered_palette(_strain_list, cm="rainbow")

# Genotype similarity ordered palette:
# Pass plain strain labels (`.values`) and a real list, as the g__Escherichia
# cell does, instead of an xarray DataArray whose items are 0-d DataArrays.
strain_palette = lib.plot.construct_ordered_palette(
    list(
        linkage_order(
            _world.genotype.linkage(optimal_ordering=True), _world.strain.values
        )
    ),
    cm="rainbow",
)

# # Construct an ordered palette, but use the order to AVOID closely correlated strains showing up with similar colors.
# strain_palette = lib.plot.construct_ordered_palette(
#     maximally_shuffled_order(
#         linkage_order(
#             sp.cluster.hierarchy.linkage(
#                 _frac[_strain_list].groupby(meta.subject_id).mean().T,
#                 method="average",
#                 metric="cosine",
#                 optimal_ordering=True,
#             ),
#             index=_strain_list,
#         )
#     ),
#     cm="rainbow",
# )

# Per-sample plotting table: stool metadata stacked on microcosm metadata, plus
# this species' relative abundance and its strain fractions.
d0 = (
    pd.concat(
        [
            # Stool samples: the label is the tuple
            # (sample id, collection_date_relative_een_end, sample_type).
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            # Microcosm samples: the date is set to +inf so they sort after all
            # stool samples of the same subject; sample_type names the inoculum.
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(
        # has_strain_deconv=lambda x: x.index.isin(_comm.index),
        # Aligned on the index; samples missing from `_species_rabund` get NaN.
        species_rabund=_species_rabund,
    )
    # Left join on the index: samples without strain fractions get NaN columns.
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

# Subjects to plot, one panel each, in display order (note "H" is third).
subject_list = list("ABHCDEFGKLMNOPQRSTU")
# One panel per subject: stacked bars of strain fractions (left axis) with the
# species' relative abundance overlaid as a black line (right axis).
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=5,
    ax_height=4,
)
fig.suptitle(_species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    # Column -1 (the "other" strain) is stacked first, i.e. at the bottom.
    d1[[-1] + _strain_list].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Same fixed x-range in every panel; bars beyond position 14 are cut off.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

# The output file name is derived from the species name.
_outpath = f"fig/een_{_species_name}_strain_tracking.pdf"
print(_outpath)
fig.savefig(_outpath)

In [None]:
# Compare two fits of the same species (the `fit-sfacts37` and `fit-sfacts48`
# outputs). Uses `_species_id` and `d0` left over from the preceding cell.
print(_species_id)
_world0 = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
)
_world1 = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
# Subsample 1000 positions from the first fit, keep only the samples with a
# non-missing "other" (-1) strain fraction in `d0`, then drop rare strains.
# NOTE(review): assumes the fit has at least 1000 positions - confirm.
w0 = (
    _world0.random_sample(position=1000)
    .sel(sample=d0.dropna(subset=[-1]).index)
    .drop_low_abundance_strains(0.05)
)
# Restrict the second fit to the same positions and samples.
w1 = _world1.sel(position=w0.position, sample=w0.sample).drop_low_abundance_strains(
    0.05
)

# Per-sample community entropy under each fit, with a y = x reference line.
plt.scatter(w0.community.entropy(), w1.community.entropy())
plt.plot([0, 1], [0, 1])

sf.plot.plot_metagenotype(w0, col_cluster=False)
sf.plot.plot_community(
    w0,
    col_cluster=False,
    row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
)
sf.plot.plot_community(
    w1,
    col_cluster=False,
    row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
)

In [None]:
# Inspect one sample under the first fit (`_world0`): print its mean
# metagenotype depth and plot its metagenotype frequency spectrum.
sample='CF_035'
print(_world0.metagenotype.mean_depth().sel(sample=sample).values)
sf.plot.plot_metagenotype_frequency_spectrum(_world0, sample)
# plt.yscale('log')

### s__Dorea scindens

In [None]:
# Focal species for this section; its taxonomy row is printed as a sanity check.
_species_id = "101303"
print(species_taxonomy.loc[_species_id])
# Species name without the "s__" prefix and with underscores, for titles/paths.
_species_name = species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
_species_rabund = species_rabund[_species_id]
_world = sf.data.World.load(
    # f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
_frac = _world.drop_low_abundance_strains(0.05).community.to_pandas()
_strain_list = list(_frac.columns)
# Raises ValueError if no -1 column is present.
_strain_list.remove(-1)  # Drop "other" strain.
# # Arbitrarily ordered palette:
# strain_palette = lib.plot.construct_ordered_palette(_strain_list, cm="rainbow")

# Genotype similarity ordered palette:
# Pass plain strain labels (`.values`) and a real list, as the g__Escherichia
# cell does, instead of an xarray DataArray whose items are 0-d DataArrays.
strain_palette = lib.plot.construct_ordered_palette(
    list(
        linkage_order(
            _world.genotype.linkage(optimal_ordering=True), _world.strain.values
        )
    ),
    cm="rainbow",
)

# # Construct an ordered palette, but use the order to AVOID closely correlated strains showing up with similar colors.
# strain_palette = lib.plot.construct_ordered_palette(
#     maximally_shuffled_order(
#         linkage_order(
#             sp.cluster.hierarchy.linkage(
#                 _frac[_strain_list].groupby(meta.subject_id).mean().T,
#                 method="average",
#                 metric="cosine",
#                 optimal_ordering=True,
#             ),
#             index=_strain_list,
#         )
#     ),
#     cm="rainbow",
# )

# Per-sample plotting table: stool metadata stacked on microcosm metadata, plus
# this species' relative abundance and its strain fractions.
d0 = (
    pd.concat(
        [
            # Stool samples: the label is the tuple
            # (sample id, collection_date_relative_een_end, sample_type).
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            # Microcosm samples: the date is set to +inf so they sort after all
            # stool samples of the same subject; sample_type names the inoculum.
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(
        # has_strain_deconv=lambda x: x.index.isin(_comm.index),
        # Aligned on the index; samples missing from `_species_rabund` get NaN.
        species_rabund=_species_rabund,
    )
    # Left join on the index: samples without strain fractions get NaN columns.
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

# Subjects to plot, one panel each, in display order (note "H" is third).
subject_list = list("ABHCDEFGKLMNOPQRSTU")
# One panel per subject: stacked bars of strain fractions (left axis) with the
# species' relative abundance overlaid as a black line (right axis).
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=5,
    ax_height=4,
)
fig.suptitle(_species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    # Column -1 (the "other" strain) is stacked first, i.e. at the bottom.
    d1[[-1] + _strain_list].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Same fixed x-range in every panel; bars beyond position 14 are cut off.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

# The output file name is derived from the species name.
_outpath = f"fig/een_{_species_name}_strain_tracking.pdf"
print(_outpath)
fig.savefig(_outpath)

In [None]:
# Compare two fits of the same species (the `fit-sfacts37` and `fit-sfacts48`
# outputs). Uses `_species_id` and `d0` left over from the preceding cell.
_world0 = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
)
_world1 = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
# Subsample 1000 positions from the first fit, keep only the samples with a
# non-missing "other" (-1) strain fraction in `d0`, then drop rare strains.
# NOTE(review): assumes the fit has at least 1000 positions - confirm.
w0 = (
    _world0.random_sample(position=1000)
    .sel(sample=d0.dropna(subset=[-1]).index)
    .drop_low_abundance_strains(0.05)
)
# Restrict the second fit to the same positions and samples.
w1 = _world1.sel(position=w0.position, sample=w0.sample).drop_low_abundance_strains(
    0.05
)

# Per-sample community entropy under each fit, with a y = x reference line and
# equal axis scaling.
fig, ax = plt.subplots()
ax.scatter(w0.community.entropy(), w1.community.entropy())
ax.plot([0, 1], [0, 1])
ax.set_aspect(1)


sf.plot.plot_metagenotype(w0, col_cluster=False)
sf.plot.plot_community(
    w0,
    col_cluster=False,
    row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
)
sf.plot.plot_community(
    w1,
    col_cluster=False,
    row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
)

### s__Hungatella hathewayi

In [None]:
# Focal species for this section. The id used to be "101303", the same as the
# Dorea section above, so this section re-plotted that species and wrote to the
# same output PDF. "100150" is the one id in the focal-species list printed
# earlier that no other section uses.
# NOTE(review): confirm from the taxonomy row printed below that 100150 is
# Hungatella hathewayi.
_species_id = "100150"
print(species_taxonomy.loc[_species_id])
# Species name without the "s__" prefix and with underscores, for titles/paths.
_species_name = species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
_species_rabund = species_rabund[_species_id]
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
)
_frac = _world.drop_low_abundance_strains(0.05).community.to_pandas()
_strain_list = list(_frac.columns)
# Raises ValueError if no -1 column is present.
_strain_list.remove(-1)  # Drop "other" strain.
# # Arbitrarily ordered palette:
# strain_palette = lib.plot.construct_ordered_palette(_strain_list, cm="rainbow")

# Genotype similarity ordered palette:
# Pass plain strain labels (`.values`) and a real list, as the g__Escherichia
# cell does, instead of an xarray DataArray whose items are 0-d DataArrays.
strain_palette = lib.plot.construct_ordered_palette(
    list(
        linkage_order(
            _world.genotype.linkage(optimal_ordering=True), _world.strain.values
        )
    ),
    cm="rainbow",
)

# # Construct an ordered palette, but use the order to AVOID closely correlated strains showing up with similar colors.
# strain_palette = lib.plot.construct_ordered_palette(
#     maximally_shuffled_order(
#         linkage_order(
#             sp.cluster.hierarchy.linkage(
#                 _frac[_strain_list].groupby(meta.subject_id).mean().T,
#                 method="average",
#                 metric="cosine",
#                 optimal_ordering=True,
#             ),
#             index=_strain_list,
#         )
#     ),
#     cm="rainbow",
# )

# Per-sample plotting table: stool metadata stacked on microcosm metadata, plus
# this species' relative abundance and its strain fractions.
d0 = (
    pd.concat(
        [
            # Stool samples: the label is the tuple
            # (sample id, collection_date_relative_een_end, sample_type).
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            # Microcosm samples: the date is set to +inf so they sort after all
            # stool samples of the same subject; sample_type names the inoculum.
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(
        # has_strain_deconv=lambda x: x.index.isin(_comm.index),
        # Aligned on the index; samples missing from `_species_rabund` get NaN.
        species_rabund=_species_rabund,
    )
    # Left join on the index: samples without strain fractions get NaN columns.
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

# Subjects to plot, one panel each, in display order (note "H" is third).
subject_list = list("ABHCDEFGKLMNOPQRSTU")
# One panel per subject: stacked bars of strain fractions (left axis) with the
# species' relative abundance overlaid as a black line (right axis).
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=5,
    ax_height=4,
)
fig.suptitle(_species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    # Column -1 (the "other" strain) is stacked first, i.e. at the bottom.
    d1[[-1] + _strain_list].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Same fixed x-range in every panel; bars beyond position 14 are cut off.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

# The output file name is derived from the species name.
_outpath = f"fig/een_{_species_name}_strain_tracking.pdf"
print(_outpath)
fig.savefig(_outpath)

### s__Eisenbergiella tayi

In [None]:
# Focal species for this section; its taxonomy row is printed as a sanity check.
_species_id = "102330"
print(species_taxonomy.loc[_species_id])
# Species name without the "s__" prefix and with underscores, for titles/paths.
_species_name = species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
_species_rabund = species_rabund[_species_id]
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
)
_frac = _world.drop_low_abundance_strains(0.05).community.to_pandas()
_strain_list = list(_frac.columns)
# Raises ValueError if no -1 column is present.
_strain_list.remove(-1)  # Drop "other" strain.
# # Arbitrarily ordered palette:
# strain_palette = lib.plot.construct_ordered_palette(_strain_list, cm="rainbow")

# Genotype similarity ordered palette:
# Pass plain strain labels (`.values`) and a real list, as the g__Escherichia
# cell does, instead of an xarray DataArray whose items are 0-d DataArrays.
strain_palette = lib.plot.construct_ordered_palette(
    list(
        linkage_order(
            _world.genotype.linkage(optimal_ordering=True), _world.strain.values
        )
    ),
    cm="rainbow",
)

# # Construct an ordered palette, but use the order to AVOID closely correlated strains showing up with similar colors.
# strain_palette = lib.plot.construct_ordered_palette(
#     maximally_shuffled_order(
#         linkage_order(
#             sp.cluster.hierarchy.linkage(
#                 _frac[_strain_list].groupby(meta.subject_id).mean().T,
#                 method="average",
#                 metric="cosine",
#                 optimal_ordering=True,
#             ),
#             index=_strain_list,
#         )
#     ),
#     cm="rainbow",
# )

# Per-sample plotting table: stool metadata stacked on microcosm metadata, plus
# this species' relative abundance and its strain fractions.
d0 = (
    pd.concat(
        [
            # Stool samples: the label is the tuple
            # (sample id, collection_date_relative_een_end, sample_type).
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            # Microcosm samples: the date is set to +inf so they sort after all
            # stool samples of the same subject; sample_type names the inoculum.
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(
        # has_strain_deconv=lambda x: x.index.isin(_comm.index),
        # Aligned on the index; samples missing from `_species_rabund` get NaN.
        species_rabund=_species_rabund,
    )
    # Left join on the index: samples without strain fractions get NaN columns.
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

# Subjects to plot, one panel each, in display order (note "H" is third).
subject_list = list("ABHCDEFGKLMNOPQRSTU")
# One panel per subject: stacked bars of strain fractions (left axis) with the
# species' relative abundance overlaid as a black line (right axis).
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=5,
    ax_height=4,
)
fig.suptitle(_species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    # Column -1 (the "other" strain) is stacked first, i.e. at the bottom.
    d1[[-1] + _strain_list].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Same fixed x-range in every panel; bars beyond position 14 are cut off.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

# The output file name is derived from the species name.
_outpath = f"fig/een_{_species_name}_strain_tracking.pdf"
print(_outpath)
fig.savefig(_outpath)

### s__Blautia hansenii

In [None]:
# Focal species for this section; its taxonomy row is printed as a sanity check.
_species_id = "101704"
print(species_taxonomy.loc[_species_id])
# Species name without the "s__" prefix and with underscores, for titles/paths.
_species_name = species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
_species_rabund = species_rabund[_species_id]
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
)
_frac = _world.drop_low_abundance_strains(0.05).community.to_pandas()
_strain_list = list(_frac.columns)
# Raises ValueError if no -1 column is present.
_strain_list.remove(-1)  # Drop "other" strain.
# # Arbitrarily ordered palette:
# strain_palette = lib.plot.construct_ordered_palette(_strain_list, cm="rainbow")

# Genotype similarity ordered palette:
# Pass plain strain labels (`.values`) and a real list, as the g__Escherichia
# cell does, instead of an xarray DataArray whose items are 0-d DataArrays.
strain_palette = lib.plot.construct_ordered_palette(
    list(
        linkage_order(
            _world.genotype.linkage(optimal_ordering=True), _world.strain.values
        )
    ),
    cm="rainbow",
)

# # Construct an ordered palette, but use the order to AVOID closely correlated strains showing up with similar colors.
# strain_palette = lib.plot.construct_ordered_palette(
#     maximally_shuffled_order(
#         linkage_order(
#             sp.cluster.hierarchy.linkage(
#                 _frac[_strain_list].groupby(meta.subject_id).mean().T,
#                 method="average",
#                 metric="cosine",
#                 optimal_ordering=True,
#             ),
#             index=_strain_list,
#         )
#     ),
#     cm="rainbow",
# )

# Per-sample plotting table: stool metadata stacked on microcosm metadata, plus
# this species' relative abundance and its strain fractions.
d0 = (
    pd.concat(
        [
            # Stool samples: the label is the tuple
            # (sample id, collection_date_relative_een_end, sample_type).
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            # Microcosm samples: the date is set to +inf so they sort after all
            # stool samples of the same subject; sample_type names the inoculum.
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(
        # has_strain_deconv=lambda x: x.index.isin(_comm.index),
        # Aligned on the index; samples missing from `_species_rabund` get NaN.
        species_rabund=_species_rabund,
    )
    # Left join on the index: samples without strain fractions get NaN columns.
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

subject_list = [
    "A",
    "B",
    "H",
    "C",
    "D",
    "E",
    "F",
    "G",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
]
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=5,
    ax_height=4,
)
fig.suptitle(_species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    d1[[-1] + _strain_list].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

_outpath = f"fig/een_{_species_name}_strain_tracking.pdf"
print(_outpath)
fig.savefig(_outpath)

## Other Species

### s__Flavonifractor plautii

In [None]:
# Strain-tracking figure for species 100099: one panel per subject showing the
# stacked strain fractions of every sample, with the species relative abundance
# drawn as a line on a twin y-axis.
_species_id = "100099"
print(species_taxonomy.loc[_species_id])
# Strip the "s__" rank prefix and replace spaces so the name can go in a filename.
_species_name = species_taxonomy.loc[_species_id].s__[len("s__") :].replace(" ", "_")
_species_rabund = species_rabund[_species_id]
_world = sf.data.World.load(
    f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
    # f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
_frac = _world.drop_low_abundance_strains(0.05).community.to_pandas()
_strain_list = list(_frac.columns)
_strain_list.remove(-1)  # Drop "other" strain.

# Construct an ordered palette, but use the order to AVOID closely correlated
# strains showing up with similar colors: cluster strains on their
# subject-mean fractions, then shuffle the resulting leaf order.
_subject_mean_frac = _frac[_strain_list].groupby(meta.subject_id).mean()
_strain_linkage = sp.cluster.hierarchy.linkage(
    _subject_mean_frac.T,
    method="average",
    metric="cosine",
    optimal_ordering=True,
)
# linkage_order(linkage, labels) fancy-indexes `labels` with the leaf order, so
# the labels must be an array (a plain list cannot be indexed by a list) and
# are passed positionally (the function has no `index` keyword).
strain_palette = lib.plot.construct_ordered_palette(
    maximally_shuffled_order(linkage_order(_strain_linkage, np.asarray(_strain_list))),
    cm="rainbow",
)

d0 = (
    pd.concat(
        [
            # Subject samples: label is (sample id, relative date, sample type).
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            # In-vitro samples: infinite relative date so they sort last.
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(species_rabund=_species_rabund)
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

subject_list = list("ABHCDEFGKLMNOPQRSTU")
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=5,
    ax_height=4,
)
fig.suptitle(_species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    # Stacked strain fractions, with the "other" strain (-1) at the bottom.
    d1[[-1] + _strain_list].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Fixed x-range so that bar widths are comparable across subjects.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

_outpath = f"fig/een_{_species_name}_strain_tracking.pdf"
print(_outpath)
fig.savefig(_outpath)

In [None]:
# Compare two sfacts fits (fit-sfacts37 vs fit-sfacts48) of the current species
# on a shared random subset of positions and the samples that have strain
# fractions in d0.
print(_species_id)
_fit_path_template = (
    "data/group/een/species/sp-{species}/"
    "r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0."
    "fit-{fit}-s85-seed0.world.nc"
)
_world0 = sf.data.World.load(
    _fit_path_template.format(species=_species_id, fit="sfacts37")
)
_world1 = sf.data.World.load(
    _fit_path_template.format(species=_species_id, fit="sfacts48")
)

# Samples with a value in the "other"-strain column (-1) of d0.
_samples_with_fractions = d0.dropna(subset=[-1]).index
w0 = _world0.random_sample(position=1000)
w0 = w0.sel(sample=_samples_with_fractions)
w0 = w0.drop_low_abundance_strains(0.05)
# Restrict the second fit to exactly the positions and samples kept for the first.
w1 = _world1.sel(position=w0.position, sample=w0.sample)
w1 = w1.drop_low_abundance_strains(0.05)

# Per-sample community entropy under each fit, with a y = x reference line.
plt.scatter(w0.community.entropy(), w1.community.entropy())
plt.plot([0, 2.5], [0, 2.5])


def _genotype_linkage(w):
    # Row linkage for the community plots, computed from the strain genotypes.
    return w.genotype.linkage(optimal_ordering=True)


sf.plot.plot_metagenotype(w0, col_cluster=False)
sf.plot.plot_community(w0, col_cluster=False, row_linkage_func=_genotype_linkage)
sf.plot.plot_community(w1, col_cluster=False, row_linkage_func=_genotype_linkage)

In [None]:
# Inspect a single sample: print its mean depth and plot its metagenotype
# frequency spectrum on a log-scaled y-axis.
samples = ["CF_022"]
_sample_depths = _world0.metagenotype.mean_depth().sel(sample=samples)
print(_sample_depths.values)
sf.plot.plot_metagenotype_frequency_spectrum_compare_samples(_world0, samples)
plt.yscale("log")

## Permutation Test

### Species Turnover

In [None]:
# Permutation test on species-level turnover: fit Bray-Curtis distance between
# within-subject sample pairs against pair type (EEN / Transition / PostEEN),
# subject, and log time separation, then compare the coefficients to those
# obtained after shuffling sample labels within each subject.

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    species_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons (condensed vectors, all in pdist order).
# The custom metrics receive one-column rows and return plain scalars;
# returning size-1 arrays relies on a deprecated NumPy array-to-scalar conversion.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]],
    metric=lambda x, y: np.abs(x[0] - y[0]),
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x[0] != y[0])
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0 (both EEN), 1 (one of each), 2 (both PostEEN).
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: float(x[0] == "PostEEN") + float(y[0] == "PostEEN"),
)

# Sample labels for every condensed entry. For the entry comparing rows i < j,
# the pair is stored as (row j, row i); the distances are symmetric.
_row_i, _row_j = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_row_j], _meta.index[_row_i])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the regression.
d1 = d0[lambda x: x.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
ax = axs[0]
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=ax)
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Append the pairwise differences between pair_type coefficients to params."""
    for left, right in [
        ("EEN", "PostEEN"),
        ("EEN", "Transition"),
        ("Transition", "PostEEN"),
    ]:
        params[f"{left} - {right}"] = (
            params[f"C(pair_type)[{left}]"] - params[f"C(pair_type)[{right}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute labels within subjects
    _perm = (
        _meta.assign(mgen_id=lambda x: x.index)
        .groupby("subject_id")
        .mgen_id.transform(np.random.permutation)
    )
    # Relabel the profiles, then restore the original row order so that the
    # condensed distances line up with the (unpermuted) pair metadata in d0.
    _perm_rabund = _rabund.rename(_perm).loc[_rabund.index]
    # Calculate pairwise comparisons
    perm_bc_cdist = pdist(_perm_rabund, metric="braycurtis")
    perm_fit = smf.ols(
        formula,
        data=d0.assign(bc=perm_bc_cdist)[lambda x: x.same_subject],
    ).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: null distribution, observed value, and p-value for
# each statistic of interest.
fig, axs = plt.subplots(4)
for ax, param_name in zip(
    axs,
    [
        "EEN - PostEEN",
        "EEN - Transition",
        "Transition - PostEEN",
        "np.log(time_delta)",
    ],
):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Diagnostics for the species-turnover fit: observed distances and
# Pearson residuals against the time separating each within-subject pair.
_cooks_distance = fit.get_influence().summary_frame().cooks_d
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=_cooks_distance,
)

# Observed distance vs. time separation, one colour per pair type.
fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    ax.scatter(d3["time_delta"], d3["bc"], label=pair_type)
ax.legend()
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_ylabel("Within-Subjects\nBray-Curtis Distance")
ax.set_xscale("symlog")

# Residuals vs. time separation, coloured by each point's Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter(d2["time_delta"], d2["resid"], c=d2["influence"])
fig.colorbar(art, label="Cook's D")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_ylabel("Residual BC\n(standardized)")
ax.set_xscale("symlog")

### Strain Turnover

In [None]:
# Permutation test on strain-level turnover: fit Bray-Curtis distance between
# within-subject sample pairs against pair type (EEN / Transition / PostEEN),
# subject, and log time separation, then compare the coefficients to those
# obtained after shuffling sample labels within each subject.

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons (condensed vectors, all in pdist order).
# The custom metrics receive one-column rows and return plain scalars;
# returning size-1 arrays relies on a deprecated NumPy array-to-scalar conversion.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]],
    metric=lambda x, y: np.abs(x[0] - y[0]),
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x[0] != y[0])
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0 (both EEN), 1 (one of each), 2 (both PostEEN).
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: float(x[0] == "PostEEN") + float(y[0] == "PostEEN"),
)

# Sample labels for every condensed entry. For the entry comparing rows i < j,
# the pair is stored as (row j, row i); the distances are symmetric.
_row_i, _row_j = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_row_j], _meta.index[_row_i])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the regression.
d1 = d0[lambda x: x.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
ax = axs[0]
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=ax)
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Append the pairwise differences between pair_type coefficients to params."""
    for left, right in [
        ("EEN", "PostEEN"),
        ("EEN", "Transition"),
        ("Transition", "PostEEN"),
    ]:
        params[f"{left} - {right}"] = (
            params[f"C(pair_type)[{left}]"] - params[f"C(pair_type)[{right}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute labels within subjects
    _perm = (
        _meta.assign(mgen_id=lambda x: x.index)
        .groupby("subject_id")
        .mgen_id.transform(np.random.permutation)
    )
    # Relabel the profiles, then restore the original row order so that the
    # condensed distances line up with the (unpermuted) pair metadata in d0.
    _perm_rabund = _rabund.rename(_perm).loc[_rabund.index]
    # Calculate pairwise comparisons
    perm_bc_cdist = pdist(_perm_rabund, metric="braycurtis")
    perm_fit = smf.ols(
        formula,
        data=d0.assign(bc=perm_bc_cdist)[lambda x: x.same_subject],
    ).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: null distribution, observed value, and p-value for
# each statistic of interest.
fig, axs = plt.subplots(4)
for ax, param_name in zip(
    axs,
    [
        "EEN - PostEEN",
        "EEN - Transition",
        "Transition - PostEEN",
        "np.log(time_delta)",
    ],
):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Diagnostics for the strain-turnover fit: observed distances and
# Pearson residuals against the time separating each within-subject pair.
_cooks_distance = fit.get_influence().summary_frame().cooks_d
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=_cooks_distance,
)

# Observed distance vs. time separation, one colour per pair type.
fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    ax.scatter(d3["time_delta"], d3["bc"], label=pair_type)
ax.legend()
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_ylabel("Within-Subjects\nBray-Curtis Distance")
ax.set_xscale("symlog")

# Residuals vs. time separation, coloured by each point's Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter(d2["time_delta"], d2["resid"], c=d2["influence"])
fig.colorbar(art, label="Cook's D")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_ylabel("Residual BC\n(standardized)")
ax.set_xscale("symlog")

### Strain Turnover beyond Species

In [None]:
# Permutation test on strain turnover beyond species turnover: the response is
# the difference between strain-level and species-level Bray-Curtis distance
# for within-subject sample pairs, modelled against pair type, subject, and
# time separation, and compared to within-subject label permutations.

# Select data
_strain_rabund, _species_rabund, _meta = lib.pandas_util.align_indexes(
    strain_rabund, species_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons (condensed vectors, all in pdist order).
# The custom metrics receive one-column rows and return plain scalars;
# returning size-1 arrays relies on a deprecated NumPy array-to-scalar conversion.
bc_strain_cdist = pdist(_strain_rabund, metric="braycurtis")
bc_species_cdist = pdist(_species_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]],
    metric=lambda x, y: np.abs(x[0] - y[0]),
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x[0] != y[0])
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0 (both EEN), 1 (one of each), 2 (both PostEEN).
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: float(x[0] == "PostEEN") + float(y[0] == "PostEEN"),
)

# Sample labels for every condensed entry. For the entry comparing rows i < j,
# the pair is stored as (row j, row i); the distances are symmetric.
_row_i, _row_j = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_row_j], _meta.index[_row_i])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc_strain=bc_strain_cdist,
        bc_species=bc_species_cdist,
        bc_diff=bc_strain_cdist - bc_species_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the regression.
d1 = d0[lambda x: x.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
ax = axs[0]
sns.boxenplot(y="bc_diff", x="same_subject", hue="pair_type", data=d0, ax=ax)
ax = axs[1]
sns.stripplot(y="bc_diff", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Append the pairwise differences between pair_type coefficients to params."""
    for left, right in [
        ("EEN", "PostEEN"),
        ("EEN", "Transition"),
        ("Transition", "PostEEN"),
    ]:
        params[f"{left} - {right}"] = (
            params[f"C(pair_type)[{left}]"] - params[f"C(pair_type)[{right}]"]
        )
    return params


# Fit observed relationship
formula = "bc_diff ~ 0 + C(pair_type) + C(subject_id, Sum) + time_delta"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute labels within subjects
    _perm = (
        _meta.assign(mgen_id=lambda x: x.index)
        .groupby("subject_id")
        .mgen_id.transform(np.random.permutation)
    )
    # Apply the SAME permutation to both tables, then restore each table's own
    # row order so the condensed distances line up with the pair metadata in
    # d0. (Uses this cell's tables rather than `_rabund`, which is only defined
    # by other cells.)
    _perm_strain_rabund = _strain_rabund.rename(_perm).loc[_strain_rabund.index]
    _perm_species_rabund = _species_rabund.rename(_perm).loc[_species_rabund.index]

    # Calculate pairwise comparisons
    perm_bc_strain_cdist = pdist(_perm_strain_rabund, metric="braycurtis")
    perm_bc_species_cdist = pdist(_perm_species_rabund, metric="braycurtis")

    perm_fit = smf.ols(
        formula,
        data=d0.assign(
            bc_strain=perm_bc_strain_cdist,
            bc_species=perm_bc_species_cdist,
            bc_diff=perm_bc_strain_cdist - perm_bc_species_cdist,
        )[lambda x: x.same_subject],
    ).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: null distribution, observed value, and p-value for
# each statistic of interest.
fig, axs = plt.subplots(4)
for ax, param_name in zip(
    axs,
    [
        "EEN - PostEEN",
        "EEN - Transition",
        "Transition - PostEEN",
        "time_delta",
    ],
):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Diagnostics for the strain-beyond-species fit: the modelled response
# (bc_diff = strain BC minus species BC) and its Pearson residuals against the
# time separating each within-subject pair.
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=fit.get_influence().summary_frame().cooks_d,
)

# Observed response vs. time separation, one colour per pair type.
fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    ax.scatter(
        "time_delta",
        "bc_diff",
        label=pair_type,
        data=d3,
    )
ax.legend()
# The y-axis shows bc_diff, not a raw Bray-Curtis distance; label it as such.
ax.set_ylabel("Within-Subjects Bray-Curtis Distance\n(strain - species)")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_xscale("symlog")

# Residuals vs. time separation, coloured by each point's Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter(
    "time_delta",
    "resid",
    c="influence",
    data=d2,
)
fig.colorbar(art, label="Cook's D")
ax.set_ylabel("Residual BC difference\n(standardized)")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_xscale("symlog")

In [None]:
# Strain-level against species-level Bray-Curtis distance for every
# within-subject pair, one colour per pair type.
for pair_type, d3 in d2.groupby("pair_type"):
    plt.scatter(d3["bc_species"], d3["bc_strain"], label=pair_type)
plt.legend()
plt.xlabel("Species BC")
plt.ylabel("Strain BC")

### Within Species Turnover Analysis

#### E. coli

In [None]:
# Within-species permutation test for species 102506: fit Bray-Curtis distance
# between the strain compositions of within-subject sample pairs against pair
# type (EEN / Transition / PostEEN), subject, and log time separation, then
# compare the coefficients to within-subject label permutations.
_species_id = "102506"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
# Mean metagenotype depth per sample; computed once and reused below.
_mean_depth = _world.metagenotype.mean_depth()
# Keep only samples whose mean depth exceeds 0.5.
single_species_strain_rabund = _world.sel(
    sample=_mean_depth > 0.5
).community.to_pandas()

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    single_species_strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons (condensed vectors, all in pdist order).
# The custom metrics receive one-column rows and return plain scalars;
# returning size-1 arrays relies on a deprecated NumPy array-to-scalar conversion.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]],
    metric=lambda x, y: np.abs(x[0] - y[0]),
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x[0] != y[0])
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0 (both EEN), 1 (one of each), 2 (both PostEEN).
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: float(x[0] == "PostEEN") + float(y[0] == "PostEEN"),
)

# Sample labels for every condensed entry. For the entry comparing rows i < j,
# the pair is stored as (row j, row i); the distances are symmetric.
_row_i, _row_j = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_row_j], _meta.index[_row_i])))
_depth_by_sample = _mean_depth.to_series()
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        sampleA_depth=_depth_by_sample.loc[pairs.str[0]].values,
        sampleB_depth=_depth_by_sample.loc[pairs.str[1]].values,
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    min_sample_depth=lambda x: np.minimum(x.sampleA_depth, x.sampleB_depth),
)
# Only within-subject pairs enter the regression.
d1 = d0[lambda x: x.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
ax = axs[0]
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=ax)
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Append the pairwise differences between pair_type coefficients to params."""
    for left, right in [
        ("EEN", "PostEEN"),
        ("EEN", "Transition"),
        ("Transition", "PostEEN"),
    ]:
        params[f"{left} - {right}"] = (
            params[f"C(pair_type)[{left}]"] - params[f"C(pair_type)[{right}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute labels within subjects
    _perm = (
        _meta.assign(mgen_id=lambda x: x.index)
        .groupby("subject_id")
        .mgen_id.transform(np.random.permutation)
    )
    # Relabel the profiles, then restore the original row order so that the
    # condensed distances line up with the (unpermuted) pair metadata in d0.
    _perm_rabund = _rabund.rename(_perm).loc[_rabund.index]
    # Calculate pairwise comparisons
    perm_bc_cdist = pdist(_perm_rabund, metric="braycurtis")
    perm_fit = smf.ols(
        formula,
        data=d0.assign(bc=perm_bc_cdist)[lambda x: x.same_subject],
    ).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: null distribution, observed value, and p-value for
# each statistic of interest.
fig, axs = plt.subplots(4)
for ax, param_name in zip(
    axs,
    [
        "EEN - PostEEN",
        "EEN - Transition",
        "Transition - PostEEN",
        "np.log(time_delta)",
    ],
):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Diagnostics for the within-species turnover fit: observed distances and
# Pearson residuals against the time separating each within-subject pair.
_cooks_distance = fit.get_influence().summary_frame().cooks_d
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=_cooks_distance,
)

# Observed distance vs. time separation, one colour per pair type.
fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    ax.scatter(d3["time_delta"], d3["bc"], label=pair_type)
ax.legend()
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_ylabel("Within-Subjects\nBray-Curtis Distance")
ax.set_xscale("symlog")

# Residuals vs. time separation, coloured by each point's Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter(d2["time_delta"], d2["resid"], c=d2["influence"])
fig.colorbar(art, label="Cook's D")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_ylabel("Residual BC\n(standardized)")
ax.set_xscale("symlog")

#### E. lenta

In [None]:
# E. lenta (species 102544): within-subject strain turnover across EEN /
# PostEEN samples, tested by permuting sample labels within subjects.
_species_id = "102544"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
# Computed once and reused for the sample filter and the per-pair depth
# columns (assumes mean_depth() is a pure function of the loaded world).
_mean_depth = _world.metagenotype.mean_depth()
single_species_strain_rabund = _world.sel(
    sample=_mean_depth > 0.5
).community.to_pandas()

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    single_species_strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons. Every vector below is in pdist's condensed
# order: pairs (i, j) with i < j, enumerated row by row.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)
# Sample labels for each condensed entry, taken straight from the
# upper-triangle positions: sampleA is sample j and sampleB is sample i.
_row_pos, _col_pos = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col_pos], _meta.index[_row_pos])))
_depth = _mean_depth.to_series()
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        sampleA_depth=_depth.loc[pairs.str[0]].values,
        sampleB_depth=_depth.loc[pairs.str[1]].values,
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
    min_sample_depth=lambda x: np.minimum(x.sampleA_depth, x.sampleB_depth),
)
# Only within-subject pairs enter the model.
d1 = d0[d0.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=axs[0])
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Add the three pairwise pair-type contrasts to *params* in place and return it."""
    for lhs, rhs in [("EEN", "PostEEN"), ("EEN", "Transition"), ("Transition", "PostEEN")]:
        params[f"{lhs} - {rhs}"] = (
            params[f"C(pair_type)[{lhs}]"] - params[f"C(pair_type)[{rhs}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute labels within subjects
    _perm = (
        _meta.assign(mgen_id=lambda x: x.index)
        .groupby("subject_id")
        .mgen_id.transform(np.random.permutation)
    )
    _perm_rabund = _rabund.rename(_perm).loc[_rabund.index]
    # Recompute distances on the relabelled profiles and keep only the
    # within-subject entries, which line up with the rows of d1.
    perm_bc_cdist = pdist(_perm_rabund, metric="braycurtis")
    perm_fit = smf.ols(
        formula, data=d1.assign(bc=perm_bc_cdist[same_subject_cdist])
    ).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: one panel per statistic, showing the permutation
# distribution, the observed value and the two-tailed permutation p-value.
_panels = [
    "EEN - PostEEN",
    "EEN - Transition",
    "Transition - PostEEN",
    "np.log(time_delta)",
]
fig, axs = plt.subplots(len(_panels))
for ax, param_name in zip(axs, _panels):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Attach fitted values, Pearson residuals and Cook's distance to each
# within-subject pair.
_influence_frame = fit.get_influence().summary_frame()
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=_influence_frame.cooks_d,
)

# Distance against time between samples, one series per pair type.
fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    ax.scatter("time_delta", "bc", data=d3, label=pair_type)
ax.legend()
ax.set(
    ylabel="Within-Subjects\nBray-Curtis Distance",
    xlabel="Within-Subjects Days between Samples",
    xscale="symlog",
)

# Residuals against time between samples, colored by Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter("time_delta", "resid", c="influence", data=d2)
fig.colorbar(art, label="Cook's D")
ax.set(
    ylabel="Residual BC\n(standardized)",
    xlabel="Within-Subjects Days between Samples",
    xscale="symlog",
)

#### C. scindens

In [None]:
# C. scindens (species 101303): within-subject strain turnover across EEN /
# PostEEN samples, tested by permuting sample labels within subjects.
_species_id = "101303"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = (
    _world.sel(sample=_world.metagenotype.mean_depth() > 1.0)
    .drop_low_abundance_strains(0.05)
    .community.to_pandas()
)

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    single_species_strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons. Every vector below is in pdist's condensed
# order: pairs (i, j) with i < j, enumerated row by row.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)
# Sample labels for each condensed entry, taken straight from the
# upper-triangle positions: sampleA is sample j and sampleB is sample i.
_row_pos, _col_pos = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col_pos], _meta.index[_row_pos])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the model.
d1 = d0[d0.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=axs[0])
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Add the three pairwise pair-type contrasts to *params* in place and return it."""
    for lhs, rhs in [("EEN", "PostEEN"), ("EEN", "Transition"), ("Transition", "PostEEN")]:
        params[f"{lhs} - {rhs}"] = (
            params[f"C(pair_type)[{lhs}]"] - params[f"C(pair_type)[{rhs}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute labels within subjects
    _perm = (
        _meta.assign(mgen_id=lambda x: x.index)
        .groupby("subject_id")
        .mgen_id.transform(np.random.permutation)
    )
    _perm_rabund = _rabund.rename(_perm).loc[_rabund.index]
    # Recompute distances on the relabelled profiles and keep only the
    # within-subject entries, which line up with the rows of d1.
    perm_bc_cdist = pdist(_perm_rabund, metric="braycurtis")
    perm_fit = smf.ols(
        formula, data=d1.assign(bc=perm_bc_cdist[same_subject_cdist])
    ).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: one panel per statistic, showing the permutation
# distribution, the observed value and the two-tailed permutation p-value.
_panels = [
    "EEN - PostEEN",
    "EEN - Transition",
    "Transition - PostEEN",
    "np.log(time_delta)",
]
fig, axs = plt.subplots(len(_panels))
for ax, param_name in zip(axs, _panels):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Attach fitted values, Pearson residuals and Cook's distance to each
# within-subject pair.
_influence_frame = fit.get_influence().summary_frame()
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=_influence_frame.cooks_d,
)

# Distance against time between samples, one series per pair type.
fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    ax.scatter("time_delta", "bc", data=d3, label=pair_type)
ax.legend()
ax.set(
    ylabel="Within-Subjects\nBray-Curtis Distance",
    xlabel="Within-Subjects Days between Samples",
    xscale="symlog",
)

# Residuals against time between samples, colored by Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter("time_delta", "resid", c="influence", data=d2)
fig.colorbar(art, label="Cook's D")
ax.set(
    ylabel="Residual BC\n(standardized)",
    xlabel="Within-Subjects Days between Samples",
    xscale="symlog",
)

## Permutation Test Prototype #2

### Species Turnover

In [None]:
# Species-level turnover across EEN / PostEEN samples, tested by permuting
# pair-type labels among within-subject sample pairs.

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    species_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons. Every vector below is in pdist's condensed
# order: pairs (i, j) with i < j, enumerated row by row.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)
# Sample labels for each condensed entry, taken straight from the
# upper-triangle positions: sampleA is sample j and sampleB is sample i.
_row_pos, _col_pos = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col_pos], _meta.index[_row_pos])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the model.
d1 = d0[d0.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=axs[0])
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Add the three pairwise pair-type contrasts to *params* in place and return it."""
    for lhs, rhs in [("EEN", "PostEEN"), ("EEN", "Transition"), ("Transition", "PostEEN")]:
        params[f"{lhs} - {rhs}"] = (
            params[f"C(pair_type)[{lhs}]"] - params[f"C(pair_type)[{rhs}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute pair-type labels within subjects, using only the within-subject
    # pairs that are modelled. Shuffling over d0 would also draw labels from
    # between-subject pairs (grouped by sampleA's subject), so the permuted
    # data would not keep the observed label counts and a pair type could
    # drop out of the fit entirely.
    _perm_type = d1.groupby("subject_id").pair_type.transform(np.random.permutation)
    perm_fit = smf.ols(formula, data=d1.assign(pair_type=_perm_type)).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: one panel per contrast, showing the permutation
# distribution, the observed value and the two-tailed permutation p-value.
_panels = ["EEN - PostEEN", "EEN - Transition", "Transition - PostEEN"]
fig, axs = plt.subplots(len(_panels))
for ax, param_name in zip(axs, _panels):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

### Strain Turnover

In [None]:
# Strain-level turnover across EEN / PostEEN samples, tested by permuting
# pair-type labels among within-subject sample pairs.

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons. Every vector below is in pdist's condensed
# order: pairs (i, j) with i < j, enumerated row by row.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)
# Sample labels for each condensed entry, taken straight from the
# upper-triangle positions: sampleA is sample j and sampleB is sample i.
_row_pos, _col_pos = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col_pos], _meta.index[_row_pos])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the model.
d1 = d0[d0.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=axs[0])
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Add the three pairwise pair-type contrasts to *params* in place and return it."""
    for lhs, rhs in [("EEN", "PostEEN"), ("EEN", "Transition"), ("Transition", "PostEEN")]:
        params[f"{lhs} - {rhs}"] = (
            params[f"C(pair_type)[{lhs}]"] - params[f"C(pair_type)[{rhs}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute pair-type labels within subjects, using only the within-subject
    # pairs that are modelled. Shuffling over d0 would also draw labels from
    # between-subject pairs (grouped by sampleA's subject), so the permuted
    # data would not keep the observed label counts and a pair type could
    # drop out of the fit entirely.
    _perm_type = d1.groupby("subject_id").pair_type.transform(np.random.permutation)
    perm_fit = smf.ols(formula, data=d1.assign(pair_type=_perm_type)).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: one panel per contrast, showing the permutation
# distribution, the observed value and the two-tailed permutation p-value.
_panels = ["EEN - PostEEN", "EEN - Transition", "Transition - PostEEN"]
fig, axs = plt.subplots(len(_panels))
for ax, param_name in zip(axs, _panels):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Attach fitted values, Pearson residuals and Cook's distance to each
# within-subject pair.
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=fit.get_influence().summary_frame().cooks_d,
)

# Time grid and slope for the fitted curves. Both are the same for every
# pair type, so they are built once instead of once per group.
xx = np.linspace(10, 500)
logxx = np.log(xx)
_time_slope = fit.params["np.log(time_delta)"]

fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    # Fitted curve for this pair type: its level plus the shared log-time
    # slope. The subject terms of the model are not included.
    yy = fit.params[f"C(pair_type)[{pair_type}]"] + logxx * _time_slope
    ax.scatter(
        "time_delta",
        "bc",
        label=pair_type,
        data=d3,
    )
    ax.plot(xx, yy)
ax.legend()
ax.set_ylabel("Within-Subjects\nBray-Curtis Distance")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_xscale("symlog")

# Residuals against time between samples, colored by Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter(
    "time_delta",
    "resid",
    c="influence",
    data=d2,
)
fig.colorbar(art, label="Cook's D")
ax.set_ylabel("Residual BC\n(standardized)")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_xscale("symlog")

### Within Species Turnover Analysis

#### E. coli

In [None]:
# E. coli (species 102506): within-subject strain turnover across EEN /
# PostEEN samples, tested by permuting pair-type labels among within-subject
# sample pairs.
_species_id = "102506"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = _world.sel(
    sample=_world.metagenotype.mean_depth() > 0.5
).community.to_pandas()

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    single_species_strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons. Every vector below is in pdist's condensed
# order: pairs (i, j) with i < j, enumerated row by row.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)
# Sample labels for each condensed entry, taken straight from the
# upper-triangle positions: sampleA is sample j and sampleB is sample i.
_row_pos, _col_pos = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col_pos], _meta.index[_row_pos])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the model.
d1 = d0[d0.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=axs[0])
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Add the three pairwise pair-type contrasts to *params* in place and return it."""
    for lhs, rhs in [("EEN", "PostEEN"), ("EEN", "Transition"), ("Transition", "PostEEN")]:
        params[f"{lhs} - {rhs}"] = (
            params[f"C(pair_type)[{lhs}]"] - params[f"C(pair_type)[{rhs}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute pair-type labels within subjects, using only the within-subject
    # pairs that are modelled. Shuffling over d0 would also draw labels from
    # between-subject pairs (grouped by sampleA's subject), so the permuted
    # data would not keep the observed label counts and a pair type could
    # drop out of the fit entirely.
    _perm_type = d1.groupby("subject_id").pair_type.transform(np.random.permutation)
    perm_fit = smf.ols(formula, data=d1.assign(pair_type=_perm_type)).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: one panel per contrast, showing the permutation
# distribution, the observed value and the two-tailed permutation p-value.
_panels = ["EEN - PostEEN", "EEN - Transition", "Transition - PostEEN"]
fig, axs = plt.subplots(len(_panels))
for ax, param_name in zip(axs, _panels):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

#### E. lenta

In [None]:
# E. lenta (species 102544): within-subject strain turnover across EEN /
# PostEEN samples, tested by permuting pair-type labels among within-subject
# sample pairs.
_species_id = "102544"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
single_species_strain_rabund = _world.sel(
    sample=_world.metagenotype.mean_depth() > 0.5
).community.to_pandas()

# Select data
_rabund, _meta = lib.pandas_util.align_indexes(
    single_species_strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons. Every vector below is in pdist's condensed
# order: pairs (i, j) with i < j, enumerated row by row.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)
# Sample labels for each condensed entry, taken straight from the
# upper-triangle positions: sampleA is sample j and sampleB is sample i.
_row_pos, _col_pos = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col_pos], _meta.index[_row_pos])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the model.
d1 = d0[d0.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=axs[0])
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


def _add_pair_type_contrasts(params):
    """Add the three pairwise pair-type contrasts to *params* in place and return it."""
    for lhs, rhs in [("EEN", "PostEEN"), ("EEN", "Transition"), ("Transition", "PostEEN")]:
        params[f"{lhs} - {rhs}"] = (
            params[f"C(pair_type)[{lhs}]"] - params[f"C(pair_type)[{rhs}]"]
        )
    return params


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute pair-type labels within subjects, using only the within-subject
    # pairs that are modelled. Shuffling over d0 would also draw labels from
    # between-subject pairs (grouped by sampleA's subject), so the permuted
    # data would not keep the observed label counts and a pair type could
    # drop out of the fit entirely.
    _perm_type = d1.groupby("subject_id").pair_type.transform(np.random.permutation)
    perm_fit = smf.ols(formula, data=d1.assign(pair_type=_perm_type)).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: one panel per contrast, showing the permutation
# distribution, the observed value and the two-tailed permutation p-value.
_panels = ["EEN - PostEEN", "EEN - Transition", "Transition - PostEEN"]
fig, axs = plt.subplots(len(_panels))
for ax, param_name in zip(axs, _panels):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

#### C. scindens

In [None]:
# C. scindens (species 101303): load the strain fit and build the table of
# pairwise sample comparisons.
_species_id = "101303"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
_deep_enough = _world.metagenotype.mean_depth() > 0.5
single_species_strain_rabund = _world.sel(sample=_deep_enough).community.to_pandas()

# Select data
_een_meta = meta[meta.sample_type.isin(["EEN", "PostEEN"])]
_rabund, _meta = lib.pandas_util.align_indexes(single_species_strain_rabund, _een_meta)

# Calculate pairwise comparisons. Every vector below is in pdist's condensed
# order: pairs (i, j) with i < j, enumerated row by row.
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of PostEEN samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)
# Sample labels for each condensed entry, taken straight from the
# upper-triangle positions: sampleA is sample j and sampleB is sample i.
_row_pos, _col_pos = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col_pos], _meta.index[_row_pos])))
d0 = pd.DataFrame(
    {
        "sampleA": pairs.str[0],
        "sampleB": pairs.str[1],
        "bc": bc_cdist,
        "same_subject": same_subject_cdist,
        "type_transition_indicator": type_transition_indicator,
        "diff_type": type_transition_indicator == 1,
        "time_delta": time_cdist,
    }
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Only within-subject pairs enter the model.
d1 = d0[d0.same_subject]


# Plot and observed relationship
fig, axs = plt.subplots(1, 2)
ax = axs[0]
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=ax)
# ax.legend()
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = fit.params
observed_stat["EEN - PostEEN"] = (
    observed_stat["C(pair_type)[EEN]"] - observed_stat["C(pair_type)[PostEEN]"]
)
observed_stat["EEN - Transition"] = (
    observed_stat["C(pair_type)[EEN]"] - observed_stat["C(pair_type)[Transition]"]
)
observed_stat["Transition - PostEEN"] = (
    observed_stat["C(pair_type)[Transition]"] - observed_stat["C(pair_type)[PostEEN]"]
)

# Calculate permutations
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute sample-pair labels within subjects
    _perm_type = (
        d0.assign(mgen_id=lambda x: x.index)
        .groupby("subject_id")
        .pair_type.transform(np.random.permutation)
    )
    perm_fit = smf.ols(
        formula,
        data=d0.assign(
            pair_type=_perm_type,
        )[lambda x: x.same_subject],
    ).fit()
    _stat = perm_fit.params
    _stat["EEN - PostEEN"] = _stat["C(pair_type)[EEN]"] - _stat["C(pair_type)[PostEEN]"]
    _stat["EEN - Transition"] = (
        _stat["C(pair_type)[EEN]"] - _stat["C(pair_type)[Transition]"]
    )
    _stat["Transition - PostEEN"] = (
        _stat["C(pair_type)[Transition]"] - _stat["C(pair_type)[PostEEN]"]
    )
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

fig, axs = plt.subplots(3)
# Plot permutation tests
ax = axs[0]
param_name = "EEN - PostEEN"
ax.set_title(param_name)
ax.hist(perm_stat[param_name], bins=20)
ax.axvline(observed_stat[param_name])
ax.annotate(
    perm_pvalues[param_name],
    xy=(0.95, 0.95),
    xycoords="axes fraction",
    ha="right",
    va="top",
)
ax = axs[1]
param_name = "EEN - Transition"
ax.set_title(param_name)
ax.hist(perm_stat[param_name], bins=20)
ax.axvline(observed_stat[param_name])
hypoth = perm_stat[param_name] > observed_stat[param_name]
null_p = (hypoth.sum() + 1) / (len(hypoth) + 1)
ax.annotate(
    perm_pvalues[param_name],
    xy=(0.95, 0.95),
    xycoords="axes fraction",
    ha="right",
    va="top",
)
ax = axs[2]
param_name = "Transition - PostEEN"
ax.set_title(param_name)
ax.hist(perm_stat[param_name], bins=20)
ax.axvline(observed_stat[param_name])
hypoth = perm_stat[param_name] > observed_stat[param_name]
null_p = (hypoth.sum() + 1) / (len(hypoth) + 1)
ax.annotate(
    perm_pvalues[param_name],
    xy=(0.95, 0.95),
    xycoords="axes fraction",
    ha="right",
    va="top",
)
fig.tight_layout()

fit.summary()

#### F. plautii

In [None]:
# Within-subject strain-composition turnover between EEN and PostEEN samples
# for a single species, tested with a permutation test on OLS contrasts.
_species_id = "100099"
path = f"data/group/een/species/sp-{_species_id}/r.proc.gtpro.sfacts-fit.world.nc"
_world = sf.data.World.load(path)
# Keep only samples whose mean metagenotype depth exceeds 0.5.
single_species_strain_rabund = _world.sel(
    sample=_world.metagenotype.mean_depth() > 0.5
).community.to_pandas()

# Select data: only EEN / PostEEN samples.
# (align_indexes presumably returns both tables on a shared index -- confirm.)
_rabund, _meta = lib.pandas_util.align_indexes(
    single_species_strain_rabund, meta[meta.sample_type.isin(["EEN", "PostEEN"])]
)

# Calculate pairwise comparisons (all condensed vectors share pdist's ordering).
bc_cdist = pdist(_rabund, metric="braycurtis")
time_cdist = pdist(
    _meta[["collection_date_relative_een_end"]], metric=lambda x, y: np.abs(x - y)
)
diff_subject_cdist = pdist(_meta[["subject_id"]], metric=lambda x, y: x != y)
same_subject_cdist = (1 - diff_subject_cdist).astype(bool)
# Number of "PostEEN" samples in the pair: 0, 1 or 2.
type_transition_indicator = pdist(
    _meta[["sample_type"]],
    metric=lambda x, y: (x == "PostEEN").astype(float) + (y == "PostEEN").astype(float),
)

# Label every pair in the same order pdist emits them: (i, j) with i < j,
# row-major, which is exactly what np.triu_indices(n, k=1) yields.
# sampleA is the later row (j) and sampleB the earlier row (i).
_row, _col = np.triu_indices(len(_meta), k=1)
pairs = pd.Series(list(zip(_meta.index[_col], _meta.index[_row])))
d0 = pd.DataFrame(
    dict(
        sampleA=pairs.str[0],
        sampleB=pairs.str[1],
        bc=bc_cdist,
        same_subject=same_subject_cdist,
        type_transition_indicator=type_transition_indicator,
        diff_type=type_transition_indicator == 1,
        time_delta=time_cdist,
    )
).assign(
    pair_type=lambda x: x.type_transition_indicator.map(
        {0: "EEN", 1: "Transition", 2: "PostEEN"}
    ),
    subject_id=lambda x: x.sampleA.map(_meta.subject_id),
    subject_id_other=lambda x: x.sampleB.map(_meta.subject_id),
)
# Within-subject pairs only; this is the table the observed model is fit on.
d1 = d0[lambda x: x.same_subject]


# Plot the observed relationship
fig, axs = plt.subplots(1, 2)
ax = axs[0]
sns.boxenplot(y="bc", x="same_subject", hue="pair_type", data=d0, ax=ax)
ax = axs[1]
sns.stripplot(y="bc", x="pair_type", hue="subject_id", data=d1, ax=ax)
ax.legend_.set_visible(False)


# Contrasts of interest: name -> (minuend level, subtrahend level) of pair_type.
_pair_type_contrasts = {
    "EEN - PostEEN": ("EEN", "PostEEN"),
    "EEN - Transition": ("EEN", "Transition"),
    "Transition - PostEEN": ("Transition", "PostEEN"),
}


def _add_pair_type_contrasts(stat):
    """Append the pairwise pair_type coefficient differences to `stat`.

    `stat` is a coefficient Series holding the `C(pair_type)[...]` entries.
    It is extended in place with one entry per contrast and returned.
    Raises KeyError if a pair_type level is missing from the fit.
    """
    for name, (lhs, rhs) in _pair_type_contrasts.items():
        stat[name] = stat[f"C(pair_type)[{lhs}]"] - stat[f"C(pair_type)[{rhs}]"]
    return stat


# Fit observed relationship
formula = "bc ~ 0 + C(pair_type) + C(subject_id, Sum) + np.log(time_delta)"
fit = smf.ols(formula, data=d1).fit()
observed_stat = _add_pair_type_contrasts(fit.params)

# Calculate permutations (fixed seed so the null distribution is reproducible)
np.random.seed(1)
perm_stat = {}
for i in tqdm(range(999)):
    # Permute sample-pair labels within subjects.
    # NOTE(review): the shuffle runs over every pair in d0 that shares
    # sampleA's subject -- between-subject pairs included -- and only then is
    # the table filtered to same-subject pairs. Permuted within-subject pairs
    # can therefore receive labels from between-subject pairs. Confirm this is
    # the intended null; shuffling d1 directly would keep label counts fixed.
    _perm_type = d0.groupby("subject_id").pair_type.transform(np.random.permutation)
    perm_fit = smf.ols(
        formula,
        data=d0.assign(pair_type=_perm_type)[lambda x: x.same_subject],
    ).fit()
    _stat = _add_pair_type_contrasts(perm_fit.params)
    perm_stat[i] = _stat
perm_stat = pd.DataFrame(perm_stat).T

perm_pvalues = _calculate_2tailed_pvalue_from_perm(observed_stat, perm_stat)

# Plot permutation tests: null histogram, observed value, two-tailed p-value.
fig, axs = plt.subplots(3)
for ax, param_name in zip(axs, _pair_type_contrasts):
    ax.set_title(param_name)
    ax.hist(perm_stat[param_name], bins=20)
    ax.axvline(observed_stat[param_name])
    ax.annotate(
        perm_pvalues[param_name],
        xy=(0.95, 0.95),
        xycoords="axes fraction",
        ha="right",
        va="top",
    )
fig.tight_layout()

fit.summary()

In [None]:
# Display the coefficient estimates of `fit`, as last assigned by an earlier cell.
fit.params

In [None]:
# Model diagnostics for `fit` on the within-subject pairs in `d1`
# (both as last assigned by an earlier cell).
d2 = d1.assign(
    predict=fit.predict(),
    resid=fit.resid_pearson,
    influence=fit.get_influence().summary_frame().cooks_d,
)

# Grid of day gaps for the fitted curves. It is the same for every pair type,
# so it is built once rather than inside the loop.
xx = np.linspace(10, 500)
logxx = np.log(xx)
_time_slope = fit.params["np.log(time_delta)"]

# Observed distances with the fitted pair-type curve overlaid.
# The curve uses only the pair_type coefficient and the log-time slope;
# no subject term is added.
fig, ax = plt.subplots()
for pair_type, d3 in d2.groupby("pair_type"):
    yy = fit.params[f"C(pair_type)[{pair_type}]"] + logxx * _time_slope
    ax.scatter(
        "time_delta",
        "bc",
        label=pair_type,
        data=d3,
    )
    ax.plot(xx, yy)
ax.legend()
ax.set_ylabel("Within-Subjects\nBray-Curtis Distance")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_xscale("symlog")

# Pearson residuals against day gap, coloured by Cook's distance.
fig, ax = plt.subplots()
art = ax.scatter(
    "time_delta",
    "resid",
    c="influence",
    data=d2,
)
fig.colorbar(art, label="Cook's D")
ax.set_ylabel("Residual BC\n(standardized)")
ax.set_xlabel("Within-Subjects Days between Samples")
ax.set_xscale("symlog")