## Preamble

In [None]:
# Load IPython's autoreload extension (note: no `%autoreload` mode is set here).
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root given by the PROJECT_ROOT environment variable
# (KeyError if unset) so the relative "meta/", "data/", "ref/" paths resolve.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Echo the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Reorder ``labels`` to follow the leaf order of a hierarchical clustering.

    ``linkage`` is a scipy linkage matrix; ``labels`` must support integer-array
    indexing (e.g. a numpy array or pandas Index) with one entry per leaf.
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]

In [None]:
def calculate_2tailed_pvalue_from_perm(obs, perms):
    """Two-tailed permutation p-value for an observed statistic.

    Each tail uses the add-one estimator ``(k + 1) / (n + 1)``, where ``k`` is
    the number of permutation values at least as extreme as ``obs`` in that
    direction (ties count as extreme). The smaller tail is doubled and the
    result is capped at 1.

    Parameters
    ----------
    obs : float or array-like
        Observed statistic: a scalar, or values positionally matched to
        ``perms`` (e.g. a repeated observed value from a joined frame).
    perms : array-like
        Null-distribution values, one per permutation.

    Returns
    -------
    float
        Two-tailed p-value in (0, 1].
    """
    perms = np.asarray(perms)
    n = len(perms)
    # Ties must count toward each tail; with strict inequalities an observed
    # value equal to every permutation value would look maximally extreme.
    n_upper = (perms >= obs).sum()
    n_lower = (perms <= obs).sum()
    p_upper = (n_upper + 1) / (n + 1)
    p_lower = (n_lower + 1) / (n + 1)
    # Doubling the smaller tail can exceed 1 when obs sits mid-distribution.
    return np.minimum(np.minimum(p_upper, p_lower) * 2, 1.0)

In [None]:
def plot_stacked_barplot(data, x_var, order, palette=None, ax=None, **kwargs):
    """Draw one stacked bar per row of ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``x_var`` (bar positions) and every column in ``order``
        (segment heights).
    x_var : str
        Column used for bar x positions and x tick locations.
    order : sequence of str
        Segment columns, stacked bottom to top in this order.
    palette : dict, optional
        Segment name -> color. Defaults to
        ``lib.plot.construct_ordered_palette(order)``.
    ax : matplotlib Axes, optional
        Target axes; a new subplot is created when omitted.
    **kwargs
        Extra ``ax.bar`` arguments; they override the default bar style.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        ax = plt.subplot()
    if palette is None:
        palette = lib.plot.construct_ordered_palette(order)

    # Default bar style; caller-supplied kwargs win on conflicts.
    style = {"width": 1.0, "alpha": 1.0, "edgecolor": "k", "lw": 1, **kwargs}

    # Running top of the stack; each segment starts where the previous ended.
    stack_top = 0
    for segment in order:
        ax.bar(
            x=data[x_var],
            height=data[segment],
            bottom=stack_top,
            label=segment,
            color=palette[segment],
            **style,
        )
        stack_top += data[segment]

    ax.set_xticks(data[x_var].values)
    return ax


def rename_timepoints_for_ts(old_tp_label):
    """Translate a timepoint label into the short form used for time series.

    Rules, checked in order:

    - float NaN (missing label) -> ``""``
    - ``"E0"`` -> ``"pE"``
    - labels starting with ``"E"``: every ``"E0"`` substring -> ``"EE"``
    - labels starting with ``"Po0"``: ``"Po0"`` -> ``"PE"``
    - labels starting with ``"Po"``: ``"Po"`` -> ``"PE"``

    Any other string returns ``None``.

    Raises
    ------
    ValueError
        If given a non-NaN float, which is not a valid label.
    """
    if isinstance(old_tp_label, float):
        if np.isnan(old_tp_label):
            return ""
        # Explicit raise rather than ``assert``, which ``python -O`` strips.
        raise ValueError(f"Unexpected numeric timepoint label: {old_tp_label!r}")
    if old_tp_label == "E0":
        return "pE"
    if old_tp_label.startswith("E"):
        return old_tp_label.replace("E0", "EE")
    if old_tp_label.startswith("Po0"):
        return old_tp_label.replace("Po0", "PE")
    if old_tp_label.startswith("Po"):
        return old_tp_label.replace("Po", "PE")
    # Unrecognized labels are passed through as None (unchanged behavior).
    return None


# Smoke test: draw a tiny two-segment stacked barplot with the default palette.
plot_stacked_barplot(
    pd.DataFrame(dict(t=[0, 1, 2], y1=[0.0, 0.5, 1.0], y2=[1.0, 0.5, 0.0])),
    x_var="t",
    order=["y1", "y2"],
)

In [None]:
# Global plot styling: seaborn "paper" context and 100-dpi figures.
sns.set_context("paper")
plt.rcParams["figure.dpi"] = 100

# Prepare Data

## Load Metadata

In [None]:
# Colors for the three within-subject pair types of the turnover analysis.
pair_type_palette = {
    "EEN:PostEEN": "tab:green",
    "EEN": "tab:blue",
    "PostEEN": "tab:orange",
}

# Colors for sample diet / culture-media categories.
diet_palette = {
    "EEN": "lightgreen",
    "PostEEN": "lightblue",
    "InVitro": "plum",
    "PreEEN": "lightpink",
}

# Display order of subjects; also fixes each subject's color below.
subject_order = [
    "A",
    "B",
    "H",
    "C",
    "D",
    "E",
    "F",
    "G",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
]

# NOTE: Requires a dummy value because I want exactly 20 items.
# (Presumably so the 20 keys line up with the 20 "tab20" colors — confirm
# against lib.plot.construct_ordered_palette.)
subject_palette = lib.plot.construct_ordered_palette(
    subject_order + [f"dummy{i}" for i in range(20 - len(subject_order))], cm="tab20"
)
subject_palette["X"] = "black"
# Plotting order, scatter marker, and fitted-line style for each pair type.
pair_type_order = ["EEN", "EEN:PostEEN", "PostEEN"]
pair_type_marker_palette = {"EEN": "s", "EEN:PostEEN": ">", "PostEEN": "o"}
pair_type_linestyle_palette = {"EEN": ":", "EEN:PostEEN": "-.", "PostEEN": "-"}

In [None]:
# List of all species with pangenome profiles

# species_id values (as strings) whose species_group_id is "een"; must be unique.
species_list = pd.read_table("meta/species_group.tsv", dtype=str)[
    lambda x: x.species_group_id == "een"
].species_id
assert species_list.is_unique
species_list = list(species_list)

In [None]:
# Sample metadata indexed by sample_id. ``label`` is a
# (collection_date_relative_een_end, diet_or_media, sample_id) tuple per sample.
sample = (
    pd.read_table("meta/een-mgen/sample.tsv")
    .assign(
        label=lambda x: x[
            ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
        ].apply(tuple, axis=1)
    )
    .set_index("sample_id")
)
# Subject metadata indexed by subject_id.
subject = pd.read_table("meta/een-mgen/subject.tsv", index_col="subject_id")

In [None]:
# Taxonomy embedded in the counts table file.
# Taxonomy embedded in the counts table file.
# NOTE(review): the zOTU data cell re-reads this same table as rotu_counts;
# here only its "taxonomy" column is used.
rotu_counts0 = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")

rotu_taxonomy0 = rotu_counts0.taxonomy
# One column per rank (d__ ... s__); pd.Series raises ValueError if a string
# does not split into exactly 7 ';'-separated fields.
rotu_taxonomy = rotu_taxonomy0.str.split(";").apply(
    lambda x: pd.Series(x, index=["d__", "p__", "c__", "o__", "f__", "g__", "s__"])
)

In [None]:
# UHGG genome catalog, restricted to species-representative genomes.
# species_id = "1" + the third "-"-separated field of MGnify_accession.
motu_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

_motu_taxonomy = (
    pd.read_table(motu_taxonomy_inpath)[lambda x: x.Genome == x.Species_rep]
    .assign(species_id=lambda x: "1" + x.MGnify_accession.str.split("-").str[2])
    .set_index("species_id")
)

motu_lineage_string = _motu_taxonomy.Lineage


def _parse_taxonomy_string(taxonomy_string):
    values = taxonomy_string.split(";")
    return pd.Series(values, index=["d__", "p__", "c__", "o__", "f__", "g__", "s__"])


# One row per species_id, one column per taxonomic rank.
motu_taxonomy = _motu_taxonomy.Lineage.apply(
    _parse_taxonomy_string
)  # .assign(taxonomy_string=motu_lineage_string)
motu_taxonomy

In [None]:
# COG-20 per-COG annotations (headerless file), indexed by COG id.
# NOTE(review): "latin10" is Python's alias for ISO-8859-16; confirm this is
# intended rather than "latin1".
cog_meta = pd.read_table(
    "ref/cog-20.meta.tsv",
    encoding="latin10",
    names=["cog", "categories", "description", "gene_name", "pathway", "_5", "color"],
    index_col="cog",
)
cog_meta

In [None]:
# COG-20 functional-category table (headerless file), indexed by category.
cog_category_meta = pd.read_table(
    "ref/cog-20.categories.tsv",
    names=["category", "color", "description"],
    index_col="category",
)
cog_category_meta

In [None]:
# Ubiquitous, single-copy genes to be used for estimating total genome depth:

schg_cog_list = [
    "COG0012",
    "COG0016",
    "COG0048",
    "COG0049",
    "COG0052",
    "COG0080",
    "COG0081",
    "COG0085",
    "COG0087",
    "COG0088",
    "COG0090",
    "COG0091",
    "COG0092",
    "COG0093",
    "COG0094",
    "COG0096",
    "COG0097",
    "COG0098",
    "COG0099",
    "COG0100",
    "COG0102",
    "COG0103",
    "COG0124",
    "COG0184",
    "COG0185",
    "COG0186",
    "COG0197",
    "COG0200",
    "COG0201",
    "COG0256",
    "COG0495",
    "COG0522",
    "COG0525",
    "COG0533",
    "COG0542",  # This one is a depth outlier... (still included; confirm intended)
]

# Show annotations for the selected COGs (KeyError if any id is missing).
cog_meta.loc[schg_cog_list]

## Load Data

### zOTUs

In [None]:
# zOTU counts as samples x zOTUs (taxonomy column dropped, then transposed),
# converted to per-sample relative abundance.
rotu_counts = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")
rotu_counts = rotu_counts.drop(columns=["taxonomy"]).T
# NOTE(review): a sample with zero total counts would become all-NaN here.
rotu_rabund = rotu_counts.divide(rotu_counts.sum(1), axis=0)

# Between-sample Bray-Curtis dissimilarity: condensed vector (input to linkage),
# square labeled DataFrame, and an average-linkage tree.
sample_rotu_bc_cdmat = sp.spatial.distance.pdist(rotu_rabund, "braycurtis")
sample_rotu_bc_pdist = pd.DataFrame(
    squareform(sample_rotu_bc_cdmat), index=rotu_rabund.index, columns=rotu_rabund.index
)
sample_rotu_bc_linkage = sp.cluster.hierarchy.linkage(
    sample_rotu_bc_cdmat, method="average", optimal_ordering=True
)

rotu_rabund

### Species

In [None]:
# Per-sample species depth as samples x species_id (missing -> 0), with sample
# names normalized to "CF_<int>" and the mislabeled CF_11/CF_15 swapped.
motu_depth = (
    pd.read_table(
        "data/group/een/r.proc.gene99_v20-v23-agg75.spgc_specgene-ref-filt-p95.all_species_depth.tsv",
        index_col=["sample", "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    .rename({"CF_15": "CF_11", "CF_11": "CF_15"})
)
motu_rabund = motu_depth.divide(motu_depth.sum(1), axis=0)

# Between-sample Bray-Curtis dissimilarity: condensed vector (input to linkage),
# square labeled DataFrame, and an average-linkage tree.
sample_motu_bc_cdmat = sp.spatial.distance.pdist(motu_rabund, "braycurtis")
sample_motu_bc_pdist = pd.DataFrame(
    squareform(sample_motu_bc_cdmat), index=motu_rabund.index, columns=motu_rabund.index
)
sample_motu_bc_linkage = sp.cluster.hierarchy.linkage(
    sample_motu_bc_cdmat, method="average", optimal_ordering=True
)

motu_rabund

In [None]:
# Per subject, the fraction of its "human" samples in which each species'
# relative abundance exceeds the threshold. Samples outside that subset get
# a NaN group key and are dropped by groupby.
motu_rabund_thresh = 1e-4
motu_prevalence_by_subject = (
    motu_rabund.gt(motu_rabund_thresh)
    .groupby(sample[lambda x: x.sample_type == "human"].subject_id)
    .mean()
)

### Strains

In [None]:
# Strain-level depth table. For each species, the species depth is split among
# its strains according to the sfacts community fit. Strains whose summed
# fraction across samples is <= 0.05 are dropped, and the remainder
# (1 - kept total, clipped at 0) becomes an "other" strain labelled -1.
sotu_depth = []
missing_files = []
for species_id in motu_depth.columns:
    path = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        d = (
            pd.read_table(path, index_col=["sample", "strain"])
            .squeeze()
            .unstack()
            .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            .rename({"CF_11": "CF_15", "CF_15": "CF_11"})  # Sample swap.
        )
    except FileNotFoundError:
        # No strain fit: all of this species' depth ends up in "<species>_-1".
        missing_files.append(path)
        d = pd.DataFrame([])
    _keep_strains = idxwhere(d.sum() > 0.05)
    assert d.index.isin(motu_depth.index).all()
    d = d.reindex(index=motu_depth.index, columns=_keep_strains, fill_value=0)
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    d[d < 0] = 0
    d = d.divide(d.sum(1), axis=0)
    d = d.multiply(motu_depth[species_id], axis=0)
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    sotu_depth.append(d)
sotu_depth = pd.concat(sotu_depth, axis=1)
sotu_rabund = sotu_depth.divide(sotu_depth.sum(1), axis=0)

sample_sotu_bc_cdmat = sp.spatial.distance.pdist(sotu_rabund, metric="braycurtis")
sample_sotu_bc_pdist = pd.DataFrame(
    squareform(sample_sotu_bc_cdmat), index=sotu_rabund.index, columns=sotu_rabund.index
)
# linkage() must receive the *condensed* distance vector, as in the zOTU and
# species cells. Given the square matrix, scipy treats each row as an
# observation vector and clusters on Euclidean distances between rows.
sample_sotu_bc_linkage = sp.cluster.hierarchy.linkage(
    sample_sotu_bc_cdmat, method="average", optimal_ordering=True
)

sns.clustermap(
    sample_sotu_bc_pdist,
    col_linkage=sample_sotu_bc_linkage,
    row_linkage=sample_sotu_bc_linkage,
)

### Genes

In [None]:
# Load table of gene depths for each species and aggregate by COG.
# Can take up to 8 minutes to compile everything
cog_depth = {}

for species in tqdm(species_list):
    # Gene-cluster (centroid_75) -> COG assignment table for this species.
    gene_x_cog_inpath = (
        f"data/species/sp-{species}/midasdb_v20.emapper.gene75_x_cog.tsv"
    )
    # Gene depth DataArray (NetCDF) for this species.
    gene_depth_inpath = (
        f"data/group/een/species/sp-{species}/r.proc.gene99_v20-v23-agg75.depth2.nc"
    )
    # Indexed by centroid_75; squeezed to a Series when a single column
    # (presumably "cog", which the groupby below needs) remains.
    _gene_x_cog = (
        pd.read_table(gene_x_cog_inpath)
        .drop_duplicates()
        .set_index("centroid_75")
        .squeeze()
    )

    # Calculate the depth of each COG by summing all genes labeled as that COG.
    # NOTE(review): a gene listed under several COGs contributes to each of them.
    _cog_depth = (
        xr.load_dataarray(gene_depth_inpath)
        .to_pandas()
        .T.join(_gene_x_cog)
        .groupby("cog")
        .sum()
    )
    cog_depth[species] = _cog_depth.stack()

# Assemble a (cog, sample, species) DataArray; combinations absent from a
# species' table are filled with 0.
cog_depth = (
    pd.DataFrame(cog_depth)
    .stack()
    .rename_axis(["cog", "sample", "species"])
    .to_xarray()
    .fillna(0)
)

# Normalize sample names and swap the mislabeled samples.
cog_depth["sample"] = (
    cog_depth.sample.to_series()
    .map(lambda x: "CF_" + str(int(x.split("_")[1])))
    .replace({"CF_15": "CF_11", "CF_11": "CF_15"})
    .values
)

# Analyses

## Species Enrichment

In [None]:
def enrichment_test(d):
    """Paired Wilcoxon signed-rank test of ``d["PostEEN"]`` against ``d["EEN"]``.

    Parameters
    ----------
    d : pandas.DataFrame
        Paired observations in columns "EEN" and "PostEEN" (one row per pair,
        presumably one per subject).

    Returns
    -------
    pandas.Series
        ``log2_ratio`` (mean of per-row log2(PostEEN / EEN)), ``mean_EEN``,
        ``mean_PostEEN`` and ``pvalue``. ``pvalue`` is NaN when scipy raises
        ValueError.
    """
    een = d["EEN"]
    post = d["PostEEN"]
    try:
        pvalue = sp.stats.wilcoxon(post, een)[1]
    except ValueError:
        pvalue = np.nan
    return pd.Series(
        {
            "log2_ratio": np.log2(post / een).mean(),
            "mean_EEN": een.mean(),
            "mean_PostEEN": post.mean(),
            "pvalue": pvalue,
        }
    )

In [None]:
# Per-species paired test of PostEEN vs EEN, using each subject's mean relative
# abundance per diet phase; subjects lacking either phase are dropped.
motu_enrichment_results = (
    # Pseudocount: add each species' smallest non-zero relative abundance so
    # log-ratios are finite. NOTE(review): an all-zero species gets +inf.
    motu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample[["subject_id", "diet_or_media"]])
    .groupby(["subject_id", "diet_or_media"])
    .mean()
    .stack()
    .unstack("diet_or_media")[["EEN", "PostEEN"]]
    .dropna()
    # .assign(log2_ratio=lambda x: np.log2(x["EEN"] / x["PostEEN"]))
    .rename_axis(index=["subject_id", "motu_id"])
    .groupby(level="motu_id")
    .apply(enrichment_test)
)

In [None]:
# Keep testable species with mean abundance > 1e-3 in either phase, then apply
# Benjamini-Hochberg FDR (statsmodels default); hits have FDR < 0.1.
motu_enrichment_results_with_fdr = (
    motu_enrichment_results.dropna(subset=["pvalue"])[
        lambda x: (x.mean_EEN > 1e-3) | (x.mean_PostEEN > 1e-3)
    ]
    .assign(
        fdr=lambda x: fdrcorrection(x.pvalue)[1],
        hit=lambda x: (True & (x.fdr < 0.1)),
    )
    .sort_values("fdr")
)
motu_enrichment_results_with_fdr.sort_values("pvalue")

In [None]:
# Volcano-style plot: log2 ratio vs p-value, colored by hit status; the log y
# axis is inverted so smaller p-values appear higher.
fig, ax = plt.subplots()
ax.scatter("log2_ratio", "pvalue", c="hit", data=motu_enrichment_results_with_fdr, s=5)
ax.set_yscale("log")
ax.yaxis.set_inverted(True)

In [None]:
# Inspect the enrichment result for one species of interest.
motu_enrichment_results_with_fdr.loc["101493"]

In [None]:
# Histogram of each species' average of its two phase means
# (pseudocount-adjusted; may include inf for all-zero species).
motu_mean_rabund = (
    motu_enrichment_results.mean_EEN + motu_enrichment_results.mean_PostEEN
) / 2
plt.hist(motu_mean_rabund)

## COG Enrichment

### Detection Limit Imputation

In [None]:
# Per-COG detection limit: smallest non-zero depth of that COG over all samples
# and species (inf when the COG is never observed).
cog_detection_limit = cog_depth.where(lambda x: x != 0, np.inf).min(
    ("sample", "species")
)
undetected_cogs_list = idxwhere((cog_detection_limit == np.inf).to_series())

# Impute zero depths with the COG's detection limit; drop never-detected COGs.
cog_depth_or_detection_limit = cog_depth.where(
    lambda x: x != 0, cog_detection_limit
).drop_sel(cog=undetected_cogs_list)

cog_depth_or_detection_limit

### Normalize COG depth

In [None]:
# Total genome depth per sample: per species, the median depth over the
# single-copy COGs, summed across species.
total_genome_depth = (
    cog_depth_or_detection_limit.sel(cog=schg_cog_list)
    .median("cog")
    .sum("species")  # NOTE: Mean or Median? Does it matter?
)
# Per-sample COG depth (summed over species) relative to total genome depth.
normalized_cog_depth_by_sample = (
    cog_depth_or_detection_limit.sum("species") / total_genome_depth
)

### Aggregate by Subject

In [None]:
# Per-subject median normalized COG depth for EEN and PostEEN samples; the
# unstack/dropna/stack keeps only subjects observed in both phases.
normalized_cog_depth_by_subject_and_type = (
    normalized_cog_depth_by_sample.to_pandas()
    .T.join(sample[["subject_id", "diet_or_media"]])[
        lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])
    ]
    .groupby(["subject_id", "diet_or_media"])
    .median()  # NOTE: Mean or Median?
    .unstack("diet_or_media")
    .dropna()
    .stack("diet_or_media")
)
normalized_cog_depth_by_subject_and_type

### Perform Tests

In [None]:
# For each COG: mean depth per phase, mean/median per-subject log2(PostEEN/EEN),
# and a paired Wilcoxon signed-rank p-value (NaN if scipy raises ValueError).
pairwise_cog_test_results = {}
for cog in tqdm(normalized_cog_depth_by_subject_and_type.columns):
    # Rows: subjects; columns: EEN and PostEEN.
    d = normalized_cog_depth_by_subject_and_type[cog].unstack()
    mean_een = d.EEN.mean()
    mean_post = d.PostEEN.mean()
    mean_log2_ratio = np.log2(d.PostEEN / d.EEN).mean()
    median_log2_ratio = np.log2(d.PostEEN / d.EEN).median()
    try:
        result = sp.stats.wilcoxon(
            d.PostEEN,
            d.EEN,
        )
        pval = result.pvalue
    except ValueError:
        pval = np.nan
    pairwise_cog_test_results[cog] = (
        mean_een,
        mean_post,
        mean_log2_ratio,
        median_log2_ratio,
        pval,
    )

pairwise_cog_test_results = pd.DataFrame(
    pairwise_cog_test_results,
    index=("mean_een", "mean_post", "mean_log2_ratio", "median_log2_ratio", "pval"),
).T

### Calculate FDR

In [None]:
# Here is where I define filters on COGs:
#    They must have a mean depth during one of the two time-periods of > 0.01
#    (this filter IS applied below, before the FDR correction).

pairwise_cog_test_results_filt_with_fdr = (
    pairwise_cog_test_results.dropna(subset=["pval"])[
        lambda x: (x.mean_een > 0.01) | (x.mean_post > 0.01)
    ]
    .assign(
        # Benjamini-Hochberg adjusted p-values (statsmodels default method).
        fdr=lambda x: fdrcorrection(x.pval)[1],
        hit=lambda x: (
            True
            & (x.fdr < 0.1)
            # & (np.abs(x.mean_log2_ratio) > 0.2)
        ),
    )
    .sort_values("fdr")
)
pairwise_cog_test_results_filt_with_fdr.sort_values("pval")

In [None]:
# Number of hit COGs with median log2(PostEEN/EEN) < 0 (higher during EEN).
# NOTE(review): the volcano plot colors by mean_log2_ratio, not the median.
pairwise_cog_test_results_filt_with_fdr[
    lambda x: (x.median_log2_ratio < 0) & x.hit
].shape

In [None]:
# Number of hit COGs with median log2(PostEEN/EEN) > 0 (higher after EEN).
pairwise_cog_test_results_filt_with_fdr[
    lambda x: (x.median_log2_ratio > 0) & x.hit
].shape

In [None]:
# COG volcano plot: hits with mean log2 ratio < 0 in blue, > 0 in orange,
# non-hits in grey. NOTE(review): a hit with ratio exactly 0 is not drawn.
d = pairwise_cog_test_results_filt_with_fdr.sort_values("pval").join(cog_meta)

fig, ax = plt.subplots(figsize=(3, 4))
ax.scatter(
    "mean_log2_ratio",
    "pval",
    data=d[d.hit & (d.mean_log2_ratio < 0)],
    color="tab:blue",
    s=10,
    alpha=0.7,
)
ax.scatter(
    "mean_log2_ratio",
    "pval",
    data=d[d.hit & (d.mean_log2_ratio > 0)],
    color="tab:orange",
    s=10,
    alpha=0.7,
)
ax.scatter("mean_log2_ratio", "pval", data=d[~d.hit], color="grey", s=5)
ax.invert_yaxis()
ax.set_yscale("log")
ax.axvline(0.0, color="black", lw=1, linestyle="--")
# ax.axvline(-0.2, color="black", lw=1, linestyle="--")
# ax.axhline(0.05, color="black", lw=1, linestyle="--")
ax.set_xlabel("Mean Log2(Fold-change)")
ax.set_ylabel("P-value")
fig.tight_layout()

fig.savefig("fig/een_gene_abundance_test.tall.pdf", bbox_inches="tight")

In [None]:
# Inspect results for specific COGs of interest (KeyError if any was filtered out).
pairwise_cog_test_results_filt_with_fdr.loc[
    ["COG3845", "COG4603", "COG4813", "COG3867"]
]

## Turnover Analysis

In [None]:
def pair_classifier(sample_typeA, sample_typeB):
    """Return an order-independent label for a pair of sample types.

    The distinct types are sorted and joined with ":", so
    ("PostEEN", "EEN") -> "EEN:PostEEN" and ("EEN", "EEN") -> "EEN".
    """
    return ":".join(sorted(set([sample_typeA, sample_typeB])))


def construct_turnover_analysis_data(
    dmat,
    meta,
    sample_type_var,
    stratum_var=None,
    time_var=None,
):
    """Build a long table of pairwise sample dissimilarities for turnover models.

    Parameters
    ----------
    dmat : pandas.DataFrame
        Square sample-by-sample dissimilarity matrix (index matches columns).
    meta : pandas.DataFrame
        Sample metadata. It is reindexed to ``dmat.index``, and samples missing
        any requested variable are dropped.
    sample_type_var : str
        Column of ``meta`` with each sample's type (used to build ``pair_type``).
    stratum_var : str, optional
        Grouping column (e.g. subject). When given, only pairs within the same
        stratum are kept and a ``stratum`` column is added.
    time_var : str, optional
        Numeric time column. When given, ``timeA``/``timeB`` and their absolute
        difference ``time_delta`` are added.

    Returns
    -------
    pandas.DataFrame
        One row per unordered pair (``sampleA < sampleB``) with columns
        sampleA, sampleB, sample_typeA, sample_typeB, diss,
        [stratumA, stratumB], [timeA, timeB], pair_type, [time_delta], [stratum].
        The columns are present even when no pair qualifies.
    """
    var_list = [v for v in (sample_type_var, stratum_var, time_var) if v is not None]
    meta = meta.reindex(dmat.index)[var_list].dropna()
    dmat = dmat.loc[meta.index, meta.index]

    columns = ["sampleA", "sampleB", "sample_typeA", "sample_typeB", "diss"]
    if stratum_var is not None:
        columns += ["stratumA", "stratumB"]
    if time_var is not None:
        columns += ["timeA", "timeB"]

    rows = []
    for idxA, idxB in product(meta.index, repeat=2):
        # Keep each unordered pair once and, when stratified, only within a
        # stratum. Filtering here avoids materializing all n**2 ordered pairs.
        if not idxA < idxB:
            continue
        if stratum_var is not None and (
            meta.loc[idxA, stratum_var] != meta.loc[idxB, stratum_var]
        ):
            continue
        row = [
            idxA,
            idxB,
            meta.loc[idxA, sample_type_var],
            meta.loc[idxB, sample_type_var],
            dmat.loc[idxA, idxB],
        ]
        if stratum_var is not None:
            row += [meta.loc[idxA, stratum_var], meta.loc[idxB, stratum_var]]
        if time_var is not None:
            row += [meta.loc[idxA, time_var], meta.loc[idxB, time_var]]
        rows.append(row)

    # Explicit columns keep the schema intact even when ``rows`` is empty.
    data = pd.DataFrame(rows, columns=columns)
    data["pair_type"] = [
        pair_classifier(a, b) for a, b in zip(data.sample_typeA, data.sample_typeB)
    ]
    if time_var is not None:
        data["time_delta"] = np.abs(data.timeB - data.timeA)
    if stratum_var is not None:
        data["stratum"] = data.stratumA
    return data

### All Species (zOTUs)

In [None]:
# zOTU-level within-subject turnover model:
# dissimilarity ~ pair type + spline(days between samples) + subject effect.
# A null distribution comes from permuting pair_type labels within each subject.
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
coef_name_list = ["pair_type[EEN]", "pair_type[PostEEN]", "pair_type[EEN:PostEEN]"]
n_perm = 999

turnover_data = construct_turnover_analysis_data(
    sample_rotu_bc_pdist,
    meta=sample[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    sample_type_var="diet_or_media",
    stratum_var="subject_id",
    time_var="collection_date_relative_een_end",
)

obs_fit = smf.ols(formula, data=turnover_data).fit()
obs_fit.summary()

# Observed pair-type coefficients plus bookkeeping counts in the same Series.
global_rotu_obs_coefs = obs_fit.params.reindex(coef_name_list)
global_rotu_obs_coefs["num_pairs"] = obs_fit.nobs
global_rotu_obs_coefs["num_subjects_with_pairs"] = turnover_data.stratum.unique().shape[
    0
]
_transition_pairs = turnover_data[lambda x: x.pair_type == "EEN:PostEEN"]
global_rotu_obs_coefs["num_transition_pairs"] = _transition_pairs.shape[0]
global_rotu_obs_coefs[
    "num_subjects_with_transition_pairs"
] = _transition_pairs.stratum.unique().shape[0]

np.random.seed(0)
_null_coef_dists = []
for _ in tqdm(range(n_perm)):
    perm_fit = smf.ols(
        formula,
        data=turnover_data.assign(
            # Shuffle pair_type within each subject. ``transform`` writes each
            # group's shuffled values back onto that group's own rows. By
            # contrast, ``groupby(...).sample(frac=1).values`` comes out in
            # group-sorted order and, assigned positionally, mixes labels
            # across subjects unless rows happen to be sorted by stratum.
            pair_type=lambda x: x.groupby("stratum").pair_type.transform(
                lambda s: s.sample(frac=1).values
            )
        ),
    ).fit()
    _null_coef_dists.append(perm_fit.params.reindex(coef_name_list))

global_rotu_null_coefs = pd.DataFrame(_null_coef_dists)

In [None]:
# zOTU-level turnover figure: within-subject pair dissimilarity vs days between
# samples, with subject-averaged model predictions per pair type. A separate
# legend-only figure is also saved.
d1 = turnover_data
fit = obs_fit

fig, ax = plt.subplots(figsize=(5, 3))
for pair_type in pair_type_order:
    d3 = d1[lambda x: (x.pair_type == pair_type)]
    ax.scatter(
        "time_delta",
        "diss",
        label="__nolegend__",
        data=d3,
        color=pair_type_palette[pair_type],
        marker=pair_type_marker_palette[pair_type],
        edgecolor="white",
        lw=0.5,
        alpha=0.75,
        s=40,
    )

# With Sum (deviation) coding, patsy omits the coefficient for the last level
# in sorted order. Use the first sorted subject, which always has one.
# Subtracting its effect gives the subject-averaged prediction, which is the
# same whichever subject is used.
_arbitrary_subject = sorted(d1.stratum.unique())[0]
predict_data = pd.DataFrame(
    product(
        [_arbitrary_subject],
        pair_type_order,
        np.logspace(1.0, 2.6),
    ),
    columns=["stratum", "pair_type", "time_delta"],
)
predict_data = predict_data.assign(
    prediction=fit.predict(predict_data),
    predict_mean_subject=lambda x: x.prediction
    - fit.params[f"C(stratum, Sum)[S.{_arbitrary_subject}]"],
)
for pair_type in pair_type_order:
    d4 = predict_data[lambda x: x.pair_type == pair_type]
    # Draw each curve only over the central 90% of observed time gaps.
    left, right = d1[lambda x: x.pair_type == pair_type].time_delta.quantile(
        [0.05, 0.95]
    )
    ax.plot(
        "time_delta",
        "predict_mean_subject",
        label="__nolegend__",
        data=d4[lambda x: (x.time_delta > left) & (x.time_delta < right)],
        color="black",
        linestyle=pair_type_linestyle_palette[pair_type],
        lw=2,
    )
ax.set_ylabel("Turnover")
ax.set_xlabel("Days between Samples")
ax.set_xscale("symlog", linthresh=1e-1)

fig.savefig("fig/een_turnover_analysis_zotu_level.pdf", bbox_inches="tight")

# Legend-only figure: empty artists carry the marker and line-style keys.
fig, ax = plt.subplots(figsize=(2.5, 1.5))
for pair_type in pair_type_order:
    ax.scatter(
        [],
        [],
        label=pair_type,
        color=pair_type_palette[pair_type],
        marker=pair_type_marker_palette[pair_type],
        edgecolor="grey",
        lw=0.5,
        s=50,
    )
for pair_type in pair_type_order:
    ax.plot(
        [],
        [],
        label=pair_type,
        color="k",
        linestyle=pair_type_linestyle_palette[pair_type],
        lw=2,
    )
ax.legend(ncols=2, handlelength=3, markerscale=1.1)
lib.plot.hide_axes_and_spines(ax=ax)

fig.savefig("fig/een_turnover_analysis_legend.pdf", bbox_inches="tight")

In [None]:
# KDE of each pair-type coefficient's permutation null, with the observed
# coefficient as a vertical line of the same color.
_obs = global_rotu_obs_coefs
_null = global_rotu_null_coefs

for coef, color in zip(coef_name_list, ["tab:blue", "tab:orange", "tab:green"]):
    sns.kdeplot(_null[coef], color=color)
    plt.axvline(_obs[coef], color=color, label=coef)
plt.legend()

In [None]:
# Contrasts between pair-type coefficients, evaluated on the observed fit and
# on every permutation; prints one two-tailed permutation p-value per contrast,
# in the order of ``comparisons``.
comparisons = dict(
    transition_vs_een=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[EEN]"],
    # NOTE(review): "pos" is presumably short for PostEEN.
    transition_vs_pos=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[PostEEN]"],
    een_vs_post=lambda x: x["pair_type[EEN]"] - x["pair_type[PostEEN]"],
    transition_vs_mean=lambda x: x["pair_type[EEN:PostEEN]"]
    - 0.5 * (x["pair_type[EEN]"] + x["pair_type[PostEEN]"]),
)

_obs = global_rotu_obs_coefs
_null = global_rotu_null_coefs


fig, axs = plt.subplots(4, sharex=True)
for comp, ax in zip(comparisons, axs):
    x, y = _null.apply(comparisons[comp], axis=1), comparisons[comp](_obs)
    ax.set_title(comp)
    ax.hist(x)
    ax.axvline(y)
    print(calculate_2tailed_pvalue_from_perm(y, x))

In [None]:
# Leftover inspection cell: displays the p-value helper function object.
calculate_2tailed_pvalue_from_perm

In [None]:
# TODO: Test for the difference between the transition and each of the others

### All Species (Metagenomics)

In [None]:
# Species-level (metagenomic) within-subject turnover model:
# dissimilarity ~ pair type + spline(days between samples) + subject effect.
# A null distribution comes from permuting pair_type labels within each subject.
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
coef_name_list = ["pair_type[EEN]", "pair_type[PostEEN]", "pair_type[EEN:PostEEN]"]
n_perm = 999

turnover_data = construct_turnover_analysis_data(
    sample_motu_bc_pdist,
    meta=sample[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    sample_type_var="diet_or_media",
    stratum_var="subject_id",
    time_var="collection_date_relative_een_end",
)

obs_fit = smf.ols(formula, data=turnover_data).fit()
obs_fit.summary()

# Observed pair-type coefficients plus bookkeeping counts in the same Series.
global_motu_obs_coefs = obs_fit.params.reindex(coef_name_list)
global_motu_obs_coefs["num_pairs"] = obs_fit.nobs
global_motu_obs_coefs["num_subjects_with_pairs"] = turnover_data.stratum.unique().shape[
    0
]
_transition_pairs = turnover_data[lambda x: x.pair_type == "EEN:PostEEN"]
global_motu_obs_coefs["num_transition_pairs"] = _transition_pairs.shape[0]
global_motu_obs_coefs[
    "num_subjects_with_transition_pairs"
] = _transition_pairs.stratum.unique().shape[0]

np.random.seed(0)
_null_coef_dists = []
for _ in tqdm(range(n_perm)):
    perm_fit = smf.ols(
        formula,
        data=turnover_data.assign(
            # Shuffle pair_type within each subject. ``transform`` writes each
            # group's shuffled values back onto that group's own rows. By
            # contrast, ``groupby(...).sample(frac=1).values`` comes out in
            # group-sorted order and, assigned positionally, mixes labels
            # across subjects unless rows happen to be sorted by stratum.
            pair_type=lambda x: x.groupby("stratum").pair_type.transform(
                lambda s: s.sample(frac=1).values
            )
        ),
    ).fit()
    _null_coef_dists.append(perm_fit.params.reindex(coef_name_list))

global_motu_null_coefs = pd.DataFrame(_null_coef_dists)

In [None]:
# KDE of each pair-type coefficient's permutation null (species level), with
# the observed coefficient as a vertical line of the same color.
_obs = global_motu_obs_coefs
_null = global_motu_null_coefs

for coef, color in zip(coef_name_list, ["tab:blue", "tab:orange", "tab:green"]):
    sns.kdeplot(_null[coef], color=color)
    plt.axvline(_obs[coef], color=color, label=coef)
plt.legend()

In [None]:
# Species-level contrasts between pair-type coefficients; prints one two-tailed
# permutation p-value per contrast, in the order of ``comparisons``.
comparisons = dict(
    transition_vs_een=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[EEN]"],
    # NOTE(review): "pos" is presumably short for PostEEN.
    transition_vs_pos=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[PostEEN]"],
    een_vs_post=lambda x: x["pair_type[EEN]"] - x["pair_type[PostEEN]"],
    transition_vs_mean=lambda x: x["pair_type[EEN:PostEEN]"]
    - 0.5 * (x["pair_type[EEN]"] + x["pair_type[PostEEN]"]),
)

_obs = global_motu_obs_coefs
_null = global_motu_null_coefs


fig, axs = plt.subplots(4, sharex=True)
for comp, ax in zip(comparisons, axs):
    x, y = _null.apply(comparisons[comp], axis=1), comparisons[comp](_obs)
    ax.set_title(comp)
    ax.hist(x)
    ax.axvline(y)
    print(calculate_2tailed_pvalue_from_perm(y, x))

In [None]:
# TODO: Test for the difference between the transition and each of the others

### All Strains

In [None]:
# Strain-level within-subject turnover model:
# dissimilarity ~ pair type + spline(days between samples) + subject effect.
# A null distribution comes from permuting pair_type labels within each subject.
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
coef_name_list = ["pair_type[EEN]", "pair_type[PostEEN]", "pair_type[EEN:PostEEN]"]
n_perm = 999

turnover_data = construct_turnover_analysis_data(
    sample_sotu_bc_pdist,
    meta=sample[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
    sample_type_var="diet_or_media",
    stratum_var="subject_id",
    time_var="collection_date_relative_een_end",
)

obs_fit = smf.ols(formula, data=turnover_data).fit()
obs_fit.summary()

# Observed pair-type coefficients plus bookkeeping counts in the same Series.
global_sotu_obs_coefs = obs_fit.params.reindex(coef_name_list)
global_sotu_obs_coefs["num_pairs"] = obs_fit.nobs
global_sotu_obs_coefs["num_subjects_with_pairs"] = turnover_data.stratum.unique().shape[
    0
]
_transition_pairs = turnover_data[lambda x: x.pair_type == "EEN:PostEEN"]
global_sotu_obs_coefs["num_transition_pairs"] = _transition_pairs.shape[0]
global_sotu_obs_coefs[
    "num_subjects_with_transition_pairs"
] = _transition_pairs.stratum.unique().shape[0]

np.random.seed(0)
_null_coef_dists = []
for _ in tqdm(range(n_perm)):
    perm_fit = smf.ols(
        formula,
        data=turnover_data.assign(
            # Shuffle pair_type within each subject. ``transform`` writes each
            # group's shuffled values back onto that group's own rows. By
            # contrast, ``groupby(...).sample(frac=1).values`` comes out in
            # group-sorted order and, assigned positionally, mixes labels
            # across subjects unless rows happen to be sorted by stratum.
            pair_type=lambda x: x.groupby("stratum").pair_type.transform(
                lambda s: s.sample(frac=1).values
            )
        ),
    ).fit()
    _null_coef_dists.append(perm_fit.params.reindex(coef_name_list))

global_sotu_null_coefs = pd.DataFrame(_null_coef_dists)

In [None]:
# Strain-level turnover figure: within-subject pair dissimilarity vs days
# between samples, with subject-averaged model predictions per pair type.
# A separate legend-only figure is also saved.
d1 = turnover_data
fit = obs_fit

fig, ax = plt.subplots(figsize=(5, 3))
for pair_type in pair_type_order:
    d3 = d1[lambda x: (x.pair_type == pair_type)]
    ax.scatter(
        "time_delta",
        "diss",
        label="__nolegend__",
        data=d3,
        color=pair_type_palette[pair_type],
        marker=pair_type_marker_palette[pair_type],
        edgecolor="white",
        lw=0.5,
        alpha=0.75,
        s=40,
    )

# With Sum (deviation) coding, patsy omits the coefficient for the last level
# in sorted order. Use the first sorted subject, which always has one.
# Subtracting its effect gives the subject-averaged prediction, which is the
# same whichever subject is used.
_arbitrary_subject = sorted(d1.stratum.unique())[0]
predict_data = pd.DataFrame(
    product(
        [_arbitrary_subject],
        pair_type_order,
        np.logspace(1.0, 2.6),
    ),
    columns=["stratum", "pair_type", "time_delta"],
)
predict_data = predict_data.assign(
    prediction=fit.predict(predict_data),
    predict_mean_subject=lambda x: x.prediction
    - fit.params[f"C(stratum, Sum)[S.{_arbitrary_subject}]"],
)
for pair_type in pair_type_order:
    d4 = predict_data[lambda x: x.pair_type == pair_type]
    # Draw each curve only over the central 90% of observed time gaps.
    left, right = d1[lambda x: x.pair_type == pair_type].time_delta.quantile(
        [0.05, 0.95]
    )
    ax.plot(
        "time_delta",
        "predict_mean_subject",
        label="__nolegend__",
        data=d4[lambda x: (x.time_delta > left) & (x.time_delta < right)],
        color="black",
        linestyle=pair_type_linestyle_palette[pair_type],
        lw=2,
    )
ax.set_ylabel("Turnover")
ax.set_xlabel("Days between Samples")
ax.set_xscale("symlog", linthresh=1e-1)

fig.savefig("fig/een_turnover_analysis_strain_level.pdf", bbox_inches="tight")

# Legend-only figure: empty artists carry the marker and line-style keys.
fig, ax = plt.subplots(figsize=(2.5, 1.5))
for pair_type in pair_type_order:
    ax.scatter(
        [],
        [],
        label=pair_type,
        color=pair_type_palette[pair_type],
        marker=pair_type_marker_palette[pair_type],
        edgecolor="grey",
        lw=0.5,
        s=50,
    )
for pair_type in pair_type_order:
    ax.plot(
        [],
        [],
        label=pair_type,
        color="k",
        linestyle=pair_type_linestyle_palette[pair_type],
        lw=2,
    )
ax.legend(ncols=2, handlelength=3, markerscale=1.1)
lib.plot.hide_axes_and_spines(ax=ax)

fig.savefig("fig/een_turnover_analysis_legend.pdf", bbox_inches="tight")

In [None]:
# KDE of each pair-type coefficient's permutation null (strain level), with
# the observed coefficient as a vertical line of the same color.
_obs = global_sotu_obs_coefs
_null = global_sotu_null_coefs

for coef, color in zip(coef_name_list, ["tab:blue", "tab:orange", "tab:green"]):
    sns.kdeplot(_null[coef], color=color)
    plt.axvline(_obs[coef], color=color, label=coef)
plt.legend()

In [None]:
# Strain-level contrasts between pair-type coefficients; prints one two-tailed
# permutation p-value per contrast, in the order of ``comparisons``.
comparisons = dict(
    transition_vs_een=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[EEN]"],
    # NOTE(review): "pos" is presumably short for PostEEN.
    transition_vs_pos=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[PostEEN]"],
    een_vs_post=lambda x: x["pair_type[EEN]"] - x["pair_type[PostEEN]"],
    transition_vs_mean=lambda x: x["pair_type[EEN:PostEEN]"]
    - 0.5 * (x["pair_type[EEN]"] + x["pair_type[PostEEN]"]),
)

_obs = global_sotu_obs_coefs
_null = global_sotu_null_coefs


fig, axs = plt.subplots(4, sharex=True)
for comp, ax in zip(comparisons, axs):
    x, y = _null.apply(comparisons[comp], axis=1), comparisons[comp](_obs)
    ax.set_title(comp)
    ax.hist(x)
    ax.axvline(y)
    print(calculate_2tailed_pvalue_from_perm(y, x))

### Per-Species

In [None]:
# Per species: count the within-subject sample pairs available for strain-level
# turnover analysis (no model fitting). Species without a fit file, or whose
# pair construction raises ValueError, are reported and skipped.
species_turnover_analysis_details = {}
for species_id in tqdm(motu_rabund.columns):
    inpath = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
    if not os.path.exists(inpath):
        print(species_id, "file missing")
        continue

    # Load the sfacts fit, normalize sample names to "CF_<int>", and apply the
    # same CF_11/CF_15 swap used for the other tables.
    sf_fit = (
        sf.data.World.load(inpath)
        .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
        .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
        # .drop_low_abundance_strains(0.01)
        # .rename_coords(strain=str)
    )

    # Between-sample Bray-Curtis dissimilarity of strain compositions.
    comm = sf_fit.community.to_pandas()
    comm_bc_pdist = pd.DataFrame(
        squareform(pdist(comm, metric="braycurtis")),
        index=comm.index,
        columns=comm.index,
    )

    try:
        turnover_data = construct_turnover_analysis_data(
            comm_bc_pdist,
            meta=sample[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
            sample_type_var="diet_or_media",
            stratum_var="subject_id",
            time_var="collection_date_relative_een_end",
        )
    except ValueError:
        print(species_id, "data doesn't work")
        continue
    species_turnover_analysis_details[species_id] = pd.Series(
        dict(
            # Mean over subjects of each subject's mean pair dissimilarity.
            overall_mean_diss=turnover_data.groupby("stratum").diss.mean().mean(),
            num_pairs=turnover_data.shape[0],
            num_subjects_with_pairs=turnover_data.stratum.unique().shape[0],
            num_transition_pairs=turnover_data[
                lambda x: x.pair_type == "EEN:PostEEN"
            ].shape[0],
            num_subjects_with_transition_pairs=turnover_data[
                lambda x: x.pair_type == "EEN:PostEEN"
            ]
            .stratum.unique()
            .shape[0],
        )
    )

species_turnover_analysis_details = pd.DataFrame(
    species_turnover_analysis_details.values(),
    index=species_turnover_analysis_details.keys(),
)

In [None]:
# Per-species strain-level turnover: fit the turnover model to each species'
# strain-composition dissimilarities and build a within-subject permutation
# null for its pair-type coefficients.
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
coef_name_list = ["pair_type[EEN]", "pair_type[PostEEN]", "pair_type[EEN:PostEEN]"]
n_perm = 999  # NOTE: The full 999 permutations take about 40 minutes.

obs_coefs = {}
null_coef_dists = []
for species_id in tqdm(motu_rabund.columns):
    inpath = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
    if not os.path.exists(inpath):
        print(species_id, "file missing")
        continue

    # Load the sfacts fit, normalize sample names to "CF_<int>", and apply the
    # same CF_11/CF_15 swap used for the other tables.
    sf_fit = (
        sf.data.World.load(inpath)
        .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
        .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
        # .drop_low_abundance_strains(0.01)
        # .rename_coords(strain=str)
    )

    comm = sf_fit.community.to_pandas()
    comm_bc_pdist = pd.DataFrame(
        squareform(pdist(comm, metric="braycurtis")),
        index=comm.index,
        columns=comm.index,
    )

    try:
        turnover_data = construct_turnover_analysis_data(
            comm_bc_pdist,
            meta=sample[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
            sample_type_var="diet_or_media",
            stratum_var="subject_id",
            time_var="collection_date_relative_een_end",
        )
    except ValueError:
        print(species_id, "data doesn't work")
        continue

    try:
        obs_fit = smf.ols(formula, data=turnover_data).fit()
    except ValueError:
        print(species_id, "fit failed")
        continue
    obs_coefs[species_id] = obs_fit.params.reindex(coef_name_list)

    np.random.seed(0)
    _null_coef_dists = []
    for _ in range(n_perm):
        perm_fit = smf.ols(
            formula,
            data=turnover_data.assign(
                # Shuffle pair_type within each subject. ``transform`` writes
                # each group's shuffled values back onto that group's own rows.
                # By contrast, ``groupby(...).sample(frac=1).values`` comes out
                # in group-sorted order and, assigned positionally, mixes
                # labels across subjects unless rows are sorted by stratum.
                pair_type=lambda x: x.groupby("stratum").pair_type.transform(
                    lambda s: s.sample(frac=1).values
                )
            ),
        ).fit()
        _null_coef_dists.append(perm_fit.params.reindex(coef_name_list))
    null_coef_dists.append(
        pd.DataFrame(_null_coef_dists, columns=coef_name_list).assign(
            species_id=species_id
        )
    )

per_species_obs_coefs = pd.DataFrame(obs_coefs).T
per_species_null_coefs = pd.concat(null_coef_dists)

In [None]:
# Per-species contrasts between pair-type coefficients, observed vs permutation
# null, and a two-tailed permutation p-value for transition_vs_mean.
comparisons = dict(
    transition_vs_een=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[EEN]"],
    transition_vs_pos=lambda x: x["pair_type[EEN:PostEEN]"] - x["pair_type[PostEEN]"],
    een_vs_post=lambda x: x["pair_type[EEN]"] - x["pair_type[PostEEN]"],
    transition_vs_mean=lambda x: x["pair_type[EEN:PostEEN]"]
    - 0.5 * (x["pair_type[EEN]"] + x["pair_type[PostEEN]"]),
)

transition_stats_obs = per_species_obs_coefs.dropna().assign(**comparisons)
transition_stats_null = per_species_null_coefs.assign(**comparisons)
# The LEFT frame holds the permutation null and the RIGHT frame the observed
# statistics, so the suffixes are "_null" and "_obs" respectively.
_stats = transition_stats_null.join(
    transition_stats_obs,
    on="species_id",
    lsuffix="_null",
    rsuffix="_obs",
    how="inner",
)

species_transition_test_pvalue = (
    _stats.dropna()
    .groupby("species_id")
    .apply(
        lambda x: calculate_2tailed_pvalue_from_perm(
            x.transition_vs_mean_obs, x.transition_vs_mean_null
        )
    )
)

In [None]:
# Per-species observed transition_vs_mean contrast vs its permutation p-value
# (log y axis, inverted so smaller p-values appear higher).
fig, ax = plt.subplots()

# ax.scatter pairs x and y by position, so align the observed contrast to the
# p-value index explicitly; the two Series can differ in order and membership.
_transition_vs_mean_obs = transition_stats_obs.transition_vs_mean.reindex(
    species_transition_test_pvalue.index
)
ax.scatter(_transition_vs_mean_obs, species_transition_test_pvalue)
ax.set_yscale("log")
ax.yaxis.set_inverted(True)
ax.set_xlim(-2.5, 2.5)

## zOTU / Metagenomics Matching

In [None]:
from sklearn.cross_decomposition import PLSSVD


def plssvd_cross_mapping(x, y, scaled=False, transpose=False):
    """Cross-map the columns of two row-aligned tables with PLS-SVD.

    Fits ``PLSSVD`` with the maximum usable number of components and returns
    the product of the x- and y-weight matrices as a feature-by-feature table.

    Parameters
    ----------
    x, y : DataFrame
        Tables whose rows (samples) are assumed to be aligned.
    scaled : bool
        Forwarded to ``PLSSVD(scale=...)``.
    transpose : bool
        Fit with the roles of ``x`` and ``y`` swapped, then transpose the
        result back so it is still indexed by ``x``'s columns.

    Returns
    -------
    (cc, pls)
        ``cc`` is indexed by ``x.columns`` with ``y.columns`` as columns;
        ``pls`` is the fitted ``PLSSVD`` model (for the fit's orientation).
    """
    left, right = (y, x) if transpose else (x, y)
    n_components = min(*left.shape, *right.shape)
    pls = PLSSVD(n_components=n_components, scale=scaled).fit(left, right)
    weight_product = pls.x_weights_ @ pls.y_weights_.T
    cc = pd.DataFrame(weight_product, index=left.columns, columns=right.columns)
    return (cc.T if transpose else cc), pls


def reciprocal_hits(x, y, score_func):
    """Score every (x-feature, y-feature) pair and rank them reciprocally.

    Parameters
    ----------
    x, y : DataFrame
        Feature tables; rows are aligned with ``align_indexes`` first.
    score_func : callable
        ``score_func(x, y) -> (coefs, aux)``. ``coefs`` is expected to be
        indexed by ``x``'s columns with ``y``'s columns as columns (as
        ``plssvd_cross_mapping`` returns); ``aux`` is passed through.

    Returns
    -------
    (result, coefs, aux)
        ``result`` is long-format (one row per pair) with columns
        ``rank_product``, ``coef`` and ``weighted_coef``.
    """
    x, y = align_indexes(x, y)
    coefs, aux = score_func(x, y)

    # Weight each coefficient by sqrt(column mean) of both features involved.
    row_scale = np.sqrt(x.mean().loc[coefs.index])
    col_scale = np.sqrt(y.mean().loc[coefs.columns])
    weighted_coef = coefs.multiply(row_scale, axis=0).multiply(col_scale, axis=1)

    # Rank 1 = largest weighted coefficient within the row (across y-features)
    # or within the column (across x-features); a product of 1 is a mutual best hit.
    rank_within_row = weighted_coef.rank(axis=1, ascending=False)
    rank_within_col = weighted_coef.rank(axis=0, ascending=False)

    result = (
        rank_within_row.mul(rank_within_col)
        .stack()
        .to_frame("rank_product")
        .assign(coef=coefs.stack(), weighted_coef=weighted_coef.stack())
    )
    return result, coefs, aux

In [None]:
# Match metagenomic species to 16S zOTUs via reciprocal PLS-SVD hits.
x, y = motu_rabund.copy(), rotu_rabund.copy()
print(len(x), len(y))

reciprocal_hits_results, cc, aux = reciprocal_hits(
    x, y, score_func=lambda a, b: plssvd_cross_mapping(a, b, scaled=False)
)
# Add taxonomy from both sides; overlapping columns become e.g. "g__mgen" / "g__zotu".
reciprocal_hits_results = reciprocal_hits_results.join(
    motu_taxonomy[["f__", "g__", "s__"]], on="species_id"
)
reciprocal_hits_results = reciprocal_hits_results.join(
    rotu_taxonomy[["f__", "g__"]], lsuffix="mgen", rsuffix="zotu"
).sort_values("zotu")


# Distribution of rank products among the top (rank product <= 2) pairs.
(
    reciprocal_hits_results.loc[reciprocal_hits_results.rank_product <= 2]
    .sort_values("zotu")["rank_product"]
    .value_counts()
)

In [None]:
# Fraction of top hits (rank product <= 2) whose genus OR family names agree.
# UHGG names carry a 3-character rank prefix (e.g. "g__") that zOTU names lack;
# only the part before the first "_" is compared.
reciprocal_hits_results[lambda x: x.rank_product <= 2].pipe(
    lambda hits: (
        hits.g__mgen.str[3:].str.split("_").str[0].eq(
            hits.g__zotu.str.split("_").str[0]
        )
        | hits.f__mgen.str[3:].str.split("_").str[0].eq(
            hits.f__zotu.str.split("_").str[0]
        )
    ).mean()
)

In [None]:
# The two species with the lowest rank product for zOTU5.
reciprocal_hits_results.xs("Zotu5", level="zotu").sort_values("rank_product").iloc[:2]

In [None]:
# Keep top hits (rank product <= 2) whose genus or family names agree between
# UHGG ("mgen", with a 3-character rank prefix) and zOTU taxonomy.
reciprocal_hits_filtered = (
    reciprocal_hits_results.loc[lambda df: df.rank_product <= 2]
    .assign(
        g_mgen_norm=lambda df: df.g__mgen.str[3:].str.split("_").str[0],
        g_zotu_norm=lambda df: df.g__zotu.str.split("_").str[0],
        g_match=lambda df: df.g_mgen_norm == df.g_zotu_norm,
        f_mgen_norm=lambda df: df.f__mgen.str[3:].str.split("_").str[0],
        f_zotu_norm=lambda df: df.f__zotu.str.split("_").str[0],
        f_match=lambda df: df.f_mgen_norm == df.f_zotu_norm,
    )
    .loc[lambda df: df.f_match | df.g_match]
)
reciprocal_hits_filtered

#### Supplementary Table S3

In [None]:
# Supplementary Table S3: accepted zOTU <-> species matches.
_s3_columns = [
    'coefficient',
    'weighted_coefficient',
    'rank_product',
    'uhgg_taxonomy',
    'uhgg_url',
    'ezbc_id',
    'network_association',
]
d = (
    reciprocal_hits_filtered
    .join(motu_taxonomy.s__.rename('uhgg_taxonomy'))
    .assign(
        coefficient=lambda df: df.coef.round(3),
        weighted_coefficient=lambda df: df.weighted_coef.round(3),
        rank_product=lambda df: df.rank_product.astype(int),
        # UHGG genome page built from the species ID with its leading digit dropped.
        uhgg_url=lambda df: "https://www.ebi.ac.uk/metagenomics/genomes/MGYG0000"
        + df.index.to_frame().species_id.str[1:],
        # Left blank; filled in manually.
        ezbc_id='',
        network_association='',
    )
    .loc[:, _s3_columns]
)

d.to_csv('fig/een_supplementary_table_s3.tsv', sep='\t')
d

In [None]:
# How many distinct zOTUs and species appear in Table S3, and what fraction of
# each sample's zOTU relative abundance do the matched zOTUs account for
# (median across samples)?
# A zOTU can be matched to more than one species (rank product <= 2 admits
# rank-1 and rank-2 partners), so de-duplicate before selecting columns:
# selecting with the raw level values would sum that zOTU's abundance repeatedly.
_matched_zotus = d.index.get_level_values("zotu").unique()
_matched_species = d.index.get_level_values("species_id").unique()
_matched_zotus.shape, _matched_species.shape, rotu_rabund[_matched_zotus].sum(1).median()

In [None]:
fig, ax = plt.subplots(figsize=(2.5, 4))

# Shared bin edges so the two density curves are directly comparable.
bins = np.linspace(-0.2, 1.0, num=41)
# All pairs: dashed outline drawn above the filled "Matches" histogram.
ax.hist(
    reciprocal_hits_results.coef,
    bins=bins,
    density=True,
    histtype="step",
    color="k",
    lw=1,
    linestyle="--",
    label="All",
    zorder=2,
)
ax.hist(
    reciprocal_hits_filtered.coef,
    bins=bins,
    density=True,
    color="tab:blue",
    label="Matches",
    zorder=1,
)
ax.set_yticks([])
ax.set_xlabel("zOTU-to-Species\nMatching Coefficient")
ax.set_ylabel('Density')
ax.legend()
fig.savefig("fig/een_species_matching_correlation_density.pdf", bbox_inches="tight")

#### zOTU5

In [None]:
# Compare zOTU5 abundance with its two candidate species (101386, 101493) and
# their sum; samples missing any value are dropped.
d = pd.DataFrame(
    dict(
        rotu=rotu_rabund[["Zotu5"]].sum(1),
        motu_A=motu_rabund[["101386"]].sum(1),
        motu_B=motu_rabund[["101493"]].sum(1),
        motu_both=motu_rabund[["101493", "101386"]].sum(1),
    )
).dropna()
print(sp.stats.pearsonr(d["motu_A"], d["rotu"]))
print(sp.stats.pearsonr(d["motu_B"], d["rotu"]))
print(sp.stats.pearsonr(d["motu_both"], d["rotu"]))

fig, ax = plt.subplots(figsize=(4, 4))
# Raw strings: "\m" is an invalid escape in a normal string literal.
# Mathtext ignores plain spaces, so "\ " keeps the space in the species name.
plt.scatter(
    "rotu",
    "motu_A",
    data=d,
    label=r"$\mathit{E.\ clostridioforme}$",
    s=20,
    alpha=0.7,
    edgecolor="grey",
    color="skyblue",
)
plt.scatter(
    "rotu",
    "motu_B",
    data=d,
    label=r"$\mathit{E.\ bolteae}$",
    s=20,
    alpha=0.7,
    edgecolor="grey",
    color="lightcoral",
)
plt.scatter(
    "rotu",
    "motu_both",
    data=d,
    label="Combined",
    s=20,
    alpha=0.7,
    edgecolor="grey",
    color="black",
)
plt.plot([0, 0.5], [0, 0.5])  # 1:1 reference line
plt.xlabel("zOTU5 relative abundance")
plt.ylabel("Species relative abundance")
plt.yscale("log")
plt.xscale("log")
# plt.ylim(1e-8, 5e1)
# ax.set_aspect(1)
plt.legend(loc="lower right", fontsize="small", markerscale=2)

plt.savefig("fig/een_zotu5_species_matching.pdf", bbox_inches="tight")

#### sp-101346 (B. uniformis)

In [None]:
# Compare B. uniformis (species 101346) abundance with its candidate zOTUs
# (Zotu48, Zotu6) and their sum; samples missing any value are dropped.
d = pd.DataFrame(
    dict(
        rotu_A=rotu_rabund[["Zotu48"]].sum(1),
        rotu_B=rotu_rabund[["Zotu6"]].sum(1),
        # rotu_C=rotu_rabund[["Zotu84"]].sum(1),
        motu=motu_rabund[["101346"]].sum(1),
        rotu_both=rotu_rabund[["Zotu6", "Zotu48"]].sum(1),
    )
).dropna()
print(sp.stats.pearsonr(d["rotu_A"], d["motu"]))
print(sp.stats.pearsonr(d["rotu_B"], d["motu"]))
# print(sp.stats.pearsonr(d["rotu_C"], d["motu"]))
print(sp.stats.pearsonr(d["rotu_both"], d["motu"]))

fig, ax = plt.subplots(figsize=(4, 4))
# Raw strings: "\m" is an invalid escape in a normal string literal.
plt.scatter(
    "motu",
    "rotu_A",
    data=d,
    label=r"$\mathit{zOTU48}$",
    s=20,
    alpha=0.7,
    edgecolor="grey",
    color="skyblue",
)
plt.scatter(
    "motu",
    "rotu_B",
    data=d,
    label=r"$\mathit{zOTU6}$",
    s=20,
    alpha=0.7,
    edgecolor="grey",
    color="lightcoral",
)
# plt.scatter(
#     "motu",
#     "rotu_C",
#     data=d,
#     label=r"$\mathit{zOTU84}$",
#     s=20,
#     alpha=0.7,
#     edgecolor="grey",
#     color="lightgreen",
# )
plt.scatter(
    "motu",
    "rotu_both",
    data=d,
    label="Combined",
    s=20,
    alpha=0.7,
    edgecolor="grey",
    color="black",
)
plt.plot([0, 0.5], [0, 0.5])  # 1:1 reference line
plt.xlabel("B. uniformis")
plt.ylabel("zOTU relative abundance")
plt.yscale("log")
plt.xscale("log")
# plt.ylim(1e-8, 5e1)
# ax.set_aspect(1)
plt.legend(loc="upper left", fontsize="small", markerscale=2)

plt.savefig("fig/een_101346_zotu_matching.pdf", bbox_inches="tight")

## Per-Species Strain Time-Series

### 101493

In [None]:
# Species ID used to locate the per-species StrainFacts fit loaded below.
species = "101493"


print(motu_taxonomy.loc[species])

In [None]:
# Load this species' StrainFacts fit and harmonize its sample/strain labels.
strain_fit = sf.data.World.load(
    f"data/group/een/species/sp-{species}/r.proc.gtpro.sfacts-fit.world.nc"
)
# Rewrite sample IDs as "CF_<integer>" (drops any zero padding in the number).
strain_fit = strain_fit.rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
)
# Swap the CF_11 / CF_15 labels. NOTE(review): presumably corrects a known
# sample-label mix-up — confirm.
strain_fit = strain_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
# Pool strains below the 0.2 threshold (semantics defined by sfacts), then
# use string strain IDs.
strain_fit = strain_fit.drop_low_abundance_strains(0.2).rename_coords(strain=str)
print(strain_fit.sizes)

# Order strains by genotype similarity so that related strains get nearby colors.
strain_linkage = strain_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, strain_fit.strain.values))
# Move the pooled "other" strain ("-1") to the end of the ordering.
strain_order.pop(strain_order.index("-1"))
strain_order.append("-1")
strain_palette = lib.plot.construct_ordered_palette(
    strain_order, cm="rainbow", extend={"-1": "silver"}
)

sf.evaluation.metagenotype_error2(strain_fit, discretized=True)[0]

In [None]:
np.random.seed(0)  # Reproducible position subsample below.

sample_linkage = strain_fit.unifrac_linkage(optimal_ordering=True)
# Plot a random subset of at most 1000 positions to keep the heatmap readable.
position_ss = strain_fit.random_sample(
    position=min(1000, strain_fit.sizes["position"])
).position

# Both heatmaps share the same (UniFrac-based) sample ordering.
sf.plot.plot_metagenotype(
    strain_fit.sel(position=position_ss),
    col_linkage_func=lambda _: sample_linkage,
)
sf.plot.plot_community(strain_fit, col_linkage_func=lambda _: sample_linkage)

In [None]:
# Strain palette keyed by the genotype-linkage order from the strain-fit cell,
# with additional shades controlled by desaturate_levels.
more_colors_strain_palette = lib.plot.construct_ordered_palette(
    strain_order,  # Genotype-linkage order
    cm="rainbow",
    extend={"-1": "silver"},
    desaturate_levels=[1.0, 0.7, 0.4],
)

# Up to 2 x 10 subject panels sharing the 0-100% y-axis.
fig, axs = plt.subplots(
    2,
    10,
    figsize=(10 * 2.7, 2 * 1.5),
    squeeze=False,
    sharey=True,
    gridspec_kw=dict(hspace=1.8, wspace=0),
)


for subject, ax in zip(subject_order, axs.flatten()):
    # This subject's samples that are present in the strain fit.
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject)) & set(strain_fit.sample.values)
    )

    try:
        # Re-pool low-abundance strains within this subject into "-1".
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # NOTE(review): presumably raised when the subject has no samples in
        # the fit; an empty community leads to the panel being hidden below.
        subject_comm = pd.DataFrame([], columns=["-1"])

    # Human samples of this subject, sorted by collection_date_relative_een_end.
    sample_list = sample.sort_values("collection_date_relative_een_end")[
        lambda x: (x.subject_id == subject)
        & (x.sample_type == "human")
        & (x.index.isin(subject_comm.index))
    ].index

    if len(sample_list) == 0:
        lib.plot.hide_axes_and_spines(ax=ax)
        continue

    # One row per sample with consecutive integer bar positions t = 0..n-1.
    d = (
        sample.reindex(sample_list)
        # .dropna(subset=["collection_date_relative_een_end"])
        # .sort_values("collection_date_relative_een_end")
        .assign(
            t=lambda x: range(len(x)),
        )
    ).join(subject_comm)
    # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

    plot_stacked_barplot(
        data=d,
        x_var="t",
        order=[s for s in strain_order if s in subject_comm.columns],
        palette=more_colors_strain_palette,
        ax=ax,
        width=0.8,
        lw=0,
    )

    ax.set_title(subject)
    # Pin exactly one tick under each bar before labelling: without explicit
    # ticks, set_xticklabels attaches labels to whatever positions the automatic
    # locator chose, which need not be 0..n-1 (labels can land on wrong bars).
    ax.set_xticks(d.t)
    ax.set_xticklabels(
        d.timepoint.map(rename_timepoints_for_ts),
        fontsize=12,
    )
    ax.set_aspect(3, anchor="NW")
    ax.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(rotation=90, ax=ax, ha="center")
    ax.set_yticks(np.linspace(0, 1.0, num=3))
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol="%"))
    ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
    ax.spines[["right", "top"]].set_visible(False)
    # ax.legend(bbox_to_anchor=(1, 1), ncols=2)

# Hide leftover panels when there are fewer subjects than panels.
for ax in axs.flatten()[len(subject_order):]:
    lib.plot.hide_axes_and_spines(ax=ax)

fig.savefig(f"fig/een.strain_timeseries.{species}.pdf", bbox_inches="tight")

### 101386

In [None]:
# Species ID used to locate the per-species StrainFacts fit loaded below.
species = "101386"


print(motu_taxonomy.loc[species])

In [None]:
# Load this species' StrainFacts fit and harmonize its sample/strain labels.
strain_fit = sf.data.World.load(
    f"data/group/een/species/sp-{species}/r.proc.gtpro.sfacts-fit.world.nc"
)
# Rewrite sample IDs as "CF_<integer>" (drops any zero padding in the number).
strain_fit = strain_fit.rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
)
# Swap the CF_11 / CF_15 labels. NOTE(review): presumably corrects a known
# sample-label mix-up — confirm.
strain_fit = strain_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
# Pool strains below the 0.2 threshold (semantics defined by sfacts), then
# use string strain IDs.
strain_fit = strain_fit.drop_low_abundance_strains(0.2).rename_coords(strain=str)
print(strain_fit.sizes)

# Order strains by genotype similarity so that related strains get nearby colors.
strain_linkage = strain_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, strain_fit.strain.values))
# Move the pooled "other" strain ("-1") to the end of the ordering.
strain_order.pop(strain_order.index("-1"))
strain_order.append("-1")
strain_palette = lib.plot.construct_ordered_palette(
    strain_order, cm="rainbow", extend={"-1": "silver"}
)

sf.evaluation.metagenotype_error2(strain_fit, discretized=True)[0]

In [None]:
np.random.seed(0)  # Reproducible position subsample below.

sample_linkage = strain_fit.unifrac_linkage(optimal_ordering=True)
# Plot a random subset of at most 1000 positions to keep the heatmap readable.
position_ss = strain_fit.random_sample(
    position=min(1000, strain_fit.sizes["position"])
).position

# Both heatmaps share the same (UniFrac-based) sample ordering.
sf.plot.plot_metagenotype(
    strain_fit.sel(position=position_ss),
    col_linkage_func=lambda _: sample_linkage,
)
sf.plot.plot_community(strain_fit, col_linkage_func=lambda _: sample_linkage)

In [None]:
# Strain palette keyed by the genotype-linkage order from the strain-fit cell,
# with additional shades controlled by desaturate_levels.
more_colors_strain_palette = lib.plot.construct_ordered_palette(
    strain_order,  # Genotype-linkage order
    cm="rainbow",
    extend={"-1": "silver"},
    desaturate_levels=[1.0, 0.7, 0.4],
)

# Up to 2 x 10 subject panels sharing the 0-100% y-axis.
fig, axs = plt.subplots(
    2,
    10,
    figsize=(10 * 2.7, 2 * 1.5),
    squeeze=False,
    sharey=True,
    gridspec_kw=dict(hspace=1.8, wspace=0),
)


for subject, ax in zip(subject_order, axs.flatten()):
    # This subject's samples that are present in the strain fit.
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject)) & set(strain_fit.sample.values)
    )

    try:
        # Re-pool low-abundance strains within this subject into "-1".
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # NOTE(review): presumably raised when the subject has no samples in
        # the fit; an empty community leads to the panel being hidden below.
        subject_comm = pd.DataFrame([], columns=["-1"])

    # Human samples of this subject, sorted by collection_date_relative_een_end.
    sample_list = sample.sort_values("collection_date_relative_een_end")[
        lambda x: (x.subject_id == subject)
        & (x.sample_type == "human")
        & (x.index.isin(subject_comm.index))
    ].index

    if len(sample_list) == 0:
        lib.plot.hide_axes_and_spines(ax=ax)
        continue

    # One row per sample with consecutive integer bar positions t = 0..n-1.
    d = (
        sample.reindex(sample_list)
        # .dropna(subset=["collection_date_relative_een_end"])
        # .sort_values("collection_date_relative_een_end")
        .assign(
            t=lambda x: range(len(x)),
        )
    ).join(subject_comm)
    # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

    plot_stacked_barplot(
        data=d,
        x_var="t",
        order=[s for s in strain_order if s in subject_comm.columns],
        palette=more_colors_strain_palette,
        ax=ax,
        width=0.8,
        lw=0,
    )

    ax.set_title(subject)
    # Pin exactly one tick under each bar before labelling: without explicit
    # ticks, set_xticklabels attaches labels to whatever positions the automatic
    # locator chose, which need not be 0..n-1 (labels can land on wrong bars).
    ax.set_xticks(d.t)
    ax.set_xticklabels(
        d.timepoint.map(rename_timepoints_for_ts),
        fontsize=12,
    )
    ax.set_aspect(3, anchor="NW")
    ax.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(rotation=90, ax=ax, ha="center")
    ax.set_yticks(np.linspace(0, 1.0, num=3))
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol="%"))
    ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
    ax.spines[["right", "top"]].set_visible(False)
    # ax.legend(bbox_to_anchor=(1, 1), ncols=2)

# Hide leftover panels when there are fewer subjects than panels.
for ax in axs.flatten()[len(subject_order):]:
    lib.plot.hide_axes_and_spines(ax=ax)

fig.savefig(f"fig/een.strain_timeseries.{species}.pdf", bbox_inches="tight")

### 102506

In [None]:
# Species ID used to locate the per-species StrainFacts fit loaded below.
species = "102506"


print(motu_taxonomy.loc[species])

In [None]:
# Load this species' StrainFacts fit and harmonize its sample/strain labels.
strain_fit = sf.data.World.load(
    f"data/group/een/species/sp-{species}/r.proc.gtpro.sfacts-fit.world.nc"
)
# Rewrite sample IDs as "CF_<integer>" (drops any zero padding in the number).
strain_fit = strain_fit.rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
)
# Swap the CF_11 / CF_15 labels. NOTE(review): presumably corrects a known
# sample-label mix-up — confirm.
strain_fit = strain_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
# Pool strains below the 0.2 threshold (semantics defined by sfacts), then
# use string strain IDs.
strain_fit = strain_fit.drop_low_abundance_strains(0.2).rename_coords(strain=str)
print(strain_fit.sizes)

# Order strains by genotype similarity so that related strains get nearby colors.
strain_linkage = strain_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, strain_fit.strain.values))
# Move the pooled "other" strain ("-1") to the end of the ordering.
strain_order.pop(strain_order.index("-1"))
strain_order.append("-1")
strain_palette = lib.plot.construct_ordered_palette(
    strain_order, cm="rainbow", extend={"-1": "silver"}
)

sf.evaluation.metagenotype_error2(strain_fit, discretized=True)[0]

In [None]:
np.random.seed(0)  # Reproducible position subsample below.

sample_linkage = strain_fit.unifrac_linkage(optimal_ordering=True)
# Plot a random subset of at most 1000 positions to keep the heatmap readable.
position_ss = strain_fit.random_sample(
    position=min(1000, strain_fit.sizes["position"])
).position

# Both heatmaps share the same (UniFrac-based) sample ordering.
sf.plot.plot_metagenotype(
    strain_fit.sel(position=position_ss),
    col_linkage_func=lambda _: sample_linkage,
)
sf.plot.plot_community(strain_fit, col_linkage_func=lambda _: sample_linkage)

In [None]:
# Strain palette keyed by the genotype-linkage order from the strain-fit cell,
# with additional shades controlled by desaturate_levels.
more_colors_strain_palette = lib.plot.construct_ordered_palette(
    strain_order,  # Genotype-linkage order
    cm="rainbow",
    extend={"-1": "silver"},
    desaturate_levels=[1.0, 0.7, 0.4],
)

# Up to 2 x 10 subject panels sharing the 0-100% y-axis.
fig, axs = plt.subplots(
    2,
    10,
    figsize=(10 * 2.7, 2 * 1.5),
    squeeze=False,
    sharey=True,
    gridspec_kw=dict(hspace=1.8, wspace=0),
)


for subject, ax in zip(subject_order, axs.flatten()):
    # This subject's samples that are present in the strain fit.
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject)) & set(strain_fit.sample.values)
    )

    try:
        # Re-pool low-abundance strains within this subject into "-1".
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # NOTE(review): presumably raised when the subject has no samples in
        # the fit; an empty community leads to the panel being hidden below.
        subject_comm = pd.DataFrame([], columns=["-1"])

    # Human samples of this subject, sorted by collection_date_relative_een_end.
    sample_list = sample.sort_values("collection_date_relative_een_end")[
        lambda x: (x.subject_id == subject)
        & (x.sample_type == "human")
        & (x.index.isin(subject_comm.index))
    ].index

    if len(sample_list) == 0:
        lib.plot.hide_axes_and_spines(ax=ax)
        continue

    # One row per sample with consecutive integer bar positions t = 0..n-1.
    d = (
        sample.reindex(sample_list)
        # .dropna(subset=["collection_date_relative_een_end"])
        # .sort_values("collection_date_relative_een_end")
        .assign(
            t=lambda x: range(len(x)),
        )
    ).join(subject_comm)
    # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

    plot_stacked_barplot(
        data=d,
        x_var="t",
        order=[s for s in strain_order if s in subject_comm.columns],
        palette=more_colors_strain_palette,
        ax=ax,
        width=0.8,
        lw=0,
    )

    ax.set_title(subject)
    # Pin exactly one tick under each bar before labelling: without explicit
    # ticks, set_xticklabels attaches labels to whatever positions the automatic
    # locator chose, which need not be 0..n-1 (labels can land on wrong bars).
    ax.set_xticks(d.t)
    ax.set_xticklabels(
        d.timepoint.map(rename_timepoints_for_ts),
        fontsize=12,
    )
    ax.set_aspect(3, anchor="NW")
    ax.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(rotation=90, ax=ax, ha="center")
    ax.set_yticks(np.linspace(0, 1.0, num=3))
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol="%"))
    ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
    ax.spines[["right", "top"]].set_visible(False)
    # ax.legend(bbox_to_anchor=(1, 1), ncols=2)

# Hide leftover panels when there are fewer subjects than panels.
for ax in axs.flatten()[len(subject_order):]:
    lib.plot.hide_axes_and_spines(ax=ax)

fig.savefig(f"fig/een.strain_timeseries.{species}.pdf", bbox_inches="tight")

### 101346

In [None]:
# Species ID used to locate the per-species StrainFacts fit loaded below.
species = "101346"


print(motu_taxonomy.loc[species])

In [None]:
# Load this species' StrainFacts fit and harmonize its sample/strain labels.
strain_fit = sf.data.World.load(
    f"data/group/een/species/sp-{species}/r.proc.gtpro.sfacts-fit.world.nc"
)
# Rewrite sample IDs as "CF_<integer>" (drops any zero padding in the number).
strain_fit = strain_fit.rename_coords(
    sample=lambda s: "CF_{}".format(int(s.split("_")[1]))
)
# Swap the CF_11 / CF_15 labels. NOTE(review): presumably corrects a known
# sample-label mix-up — confirm.
strain_fit = strain_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
# Pool strains below the 0.2 threshold (semantics defined by sfacts), then
# use string strain IDs.
strain_fit = strain_fit.drop_low_abundance_strains(0.2).rename_coords(strain=str)
print(strain_fit.sizes)

# Order strains by genotype similarity so that related strains get nearby colors.
strain_linkage = strain_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, strain_fit.strain.values))
# Move the pooled "other" strain ("-1") to the end of the ordering.
strain_order.pop(strain_order.index("-1"))
strain_order.append("-1")
strain_palette = lib.plot.construct_ordered_palette(
    strain_order, cm="rainbow", extend={"-1": "silver"}
)

sf.evaluation.metagenotype_error2(strain_fit, discretized=True)[0]

In [None]:
np.random.seed(0)  # Reproducible position subsample below.

sample_linkage = strain_fit.unifrac_linkage(optimal_ordering=True)
# Plot a random subset of at most 1000 positions to keep the heatmap readable.
position_ss = strain_fit.random_sample(
    position=min(1000, strain_fit.sizes["position"])
).position

# Both heatmaps share the same (UniFrac-based) sample ordering.
sf.plot.plot_metagenotype(
    strain_fit.sel(position=position_ss),
    col_linkage_func=lambda _: sample_linkage,
)
sf.plot.plot_community(strain_fit, col_linkage_func=lambda _: sample_linkage)

In [None]:
# Strain palette keyed by the genotype-linkage order from the strain-fit cell,
# with additional shades controlled by desaturate_levels.
more_colors_strain_palette = lib.plot.construct_ordered_palette(
    strain_order,  # Genotype-linkage order
    cm="rainbow",
    extend={"-1": "silver"},
    desaturate_levels=[1.0, 0.7, 0.4],
)

# Up to 2 x 10 subject panels sharing the 0-100% y-axis.
fig, axs = plt.subplots(
    2,
    10,
    figsize=(10 * 2.7, 2 * 1.5),
    squeeze=False,
    sharey=True,
    gridspec_kw=dict(hspace=1.8, wspace=0),
)


for subject, ax in zip(subject_order, axs.flatten()):
    # This subject's samples that are present in the strain fit.
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject)) & set(strain_fit.sample.values)
    )

    try:
        # Re-pool low-abundance strains within this subject into "-1".
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # NOTE(review): presumably raised when the subject has no samples in
        # the fit; an empty community leads to the panel being hidden below.
        subject_comm = pd.DataFrame([], columns=["-1"])

    # Human samples of this subject, sorted by collection_date_relative_een_end.
    sample_list = sample.sort_values("collection_date_relative_een_end")[
        lambda x: (x.subject_id == subject)
        & (x.sample_type == "human")
        & (x.index.isin(subject_comm.index))
    ].index

    if len(sample_list) == 0:
        lib.plot.hide_axes_and_spines(ax=ax)
        continue

    # One row per sample with consecutive integer bar positions t = 0..n-1.
    d = (
        sample.reindex(sample_list)
        # .dropna(subset=["collection_date_relative_een_end"])
        # .sort_values("collection_date_relative_een_end")
        .assign(
            t=lambda x: range(len(x)),
        )
    ).join(subject_comm)
    # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

    plot_stacked_barplot(
        data=d,
        x_var="t",
        order=[s for s in strain_order if s in subject_comm.columns],
        palette=more_colors_strain_palette,
        ax=ax,
        width=0.8,
        lw=0,
    )

    ax.set_title(subject)
    # Pin exactly one tick under each bar before labelling: without explicit
    # ticks, set_xticklabels attaches labels to whatever positions the automatic
    # locator chose, which need not be 0..n-1 (labels can land on wrong bars).
    ax.set_xticks(d.t)
    ax.set_xticklabels(
        d.timepoint.map(rename_timepoints_for_ts),
        fontsize=12,
    )
    ax.set_aspect(3, anchor="NW")
    ax.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(rotation=90, ax=ax, ha="center")
    ax.set_yticks(np.linspace(0, 1.0, num=3))
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol="%"))
    ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
    ax.spines[["right", "top"]].set_visible(False)
    # ax.legend(bbox_to_anchor=(1, 1), ncols=2)

# Hide leftover panels when there are fewer subjects than panels.
for ax in axs.flatten()[len(subject_order):]:
    lib.plot.hide_axes_and_spines(ax=ax)

fig.savefig(f"fig/een.strain_timeseries.{species}.pdf", bbox_inches="tight")

## Species Summary and Scoring

In [None]:
# Per-species summary (enrichment, turnover, zOTU matches) plus a combined
# "indicator score": |log2 fold-change| x pairs x mean abundance x mean turnover.
d0 = (
    motu_enrichment_results.join(per_species_obs_coefs)
    .join(species_turnover_analysis_details)
    .assign(
        uhgg_taxonomy=motu_lineage_string,
        zotu_match=reciprocal_hits_filtered.reset_index().groupby('species_id').zotu.apply(', '.join),
        subject_prevalence=(motu_prevalence_by_subject > 0).mean(),
        mean_rabund=lambda x: (x.mean_EEN + x.mean_PostEEN) / 2,
        indicator_score=lambda x: np.abs(x.log2_ratio)
        * x.num_pairs
        * x.mean_rabund
        * x.overall_mean_diss,
    )
    .rename(
        columns=dict(
            log2_ratio="species_log2_fold_change",
            mean_rabund="species_overall_mean_relative_abundance",
            num_pairs="num_intrasubject_sample_pairs",
            overall_mean_diss="mean_pairwise_braycurtis_dissimilarity",
            zotu_match="zotu_match",
        )
    )
)

# # Save statistics
# (
#     d0.sort_values("indicator_score", ascending=False)[
#         lambda x: x.indicator_score > 0.1
#     ][
#         [
#             "indicator_score",
#             "species_log2_fold_change",
#             "species_overall_mean_relative_abundance",
#             "mean_pairwise_braycurtis_dissimilarity",
#             "num_intrasubject_sample_pairs",
#             "uhgg_taxonomy",
#             "zotu_match",
#         ]
#     ]
#     .round(3)
#     .to_csv("fig/een_turnover_stats.tsv", sep="\t")
# )

# Plot species with > 20 intra-subject pairs and a defined mean abundance.
d1 = d0[
    lambda x: (x.num_intrasubject_sample_pairs > 20)
    & (~x.species_overall_mean_relative_abundance.isna())
].assign(zorder=0)


# species_id -> (annotation label, annotation text offset in points)
focal_species_style_map = {
    "101493": ("E. bolteae", (-10, 10)),
    "101386": ("E. clostridioforme", (-20, -30)),
    "102506": ("E. coli", (-10, 10)),
    "101346": ("B. uniformis", (-40, 10)),
}

d1.loc[list(focal_species_style_map), "zorder"] = 1

# One color normalization, spanning every plotted species, shared by both
# scatter layers and the colorbar. Otherwise the focal layer is colored on its
# own linear scale and plt.colorbar() (which picks the most recent mappable)
# describes only the focal species.
score_norm = mpl.colors.SymLogNorm(
    linthresh=0.1,
    linscale=0.5,
    vmin=d1.indicator_score.min(),
    vmax=d1.indicator_score.max(),
)

fig, ax = plt.subplots(figsize=(6, 4))
background_points = plt.scatter(
    "species_log2_fold_change",
    "mean_pairwise_braycurtis_dissimilarity",
    data=d1[d1.zorder == 0],
    c="indicator_score",
    zorder=0,
    norm=score_norm,
    s=50,
    edgecolor="grey",
    alpha=0.7,
)

# Focal species: drawn on top, slightly larger, on the same color scale.
plt.scatter(
    "species_log2_fold_change",
    "mean_pairwise_braycurtis_dissimilarity",
    data=d1[d1.zorder == 1],
    c="indicator_score",
    zorder=1,
    norm=score_norm,
    s=70,
    edgecolor="grey",
    alpha=0.7,
)
# plt.ylim(-0.5, 1.5)
cbar = plt.colorbar(background_points, label="Score", alpha=1.0)
cbar.solids.set(alpha=1)

for _species, (_name, (textx, texty)) in focal_species_style_map.items():
    # plt.scatter(
    #     "species_log2_fold_change",
    #     "mean_pairwise_braycurtis_dissimilarity",
    #     data=d1.loc[[_species]],
    #     edgecolor=_color,
    #     facecolor="none",
    #     s=200,
    # )
    plt.annotate(
        _name,
        xy=d1.loc[
            _species,
            ["species_log2_fold_change", "mean_pairwise_braycurtis_dissimilarity"],
        ].values,
        xytext=(textx, texty),
        textcoords="offset points",
        # rotation=rotation,
        fontstyle='italic',
        arrowprops=dict(arrowstyle="->", color='darkred', linewidth=1.2),
        ha='right',
    )

plt.xlabel("Mean Log2(Fold-change)")
plt.ylabel("Mean Turnover")
# plt.yscale('symlog', linthresh=1e-2)
# plt.xlim(-13, 6)
plt.xticks([-12, -9, -6, -3, 0, 3, 6, 9, 12])
plt.savefig("fig/een_turnover_stats.pdf", bbox_inches="tight")


print(
    sp.stats.spearmanr(
        d1.species_log2_fold_change, d1.mean_pairwise_braycurtis_dissimilarity
    )
)

#### Supplementary Table S4

In [None]:
# Display the per-species permutation p-values for the transition_vs_mean contrast.
species_transition_test_pvalue

In [None]:
# Per-species summary for Supplementary Table S4, including the transition
# turnover contrast and its permutation p-value (aligned by species index).
d0 = (
    motu_enrichment_results.join(per_species_obs_coefs)
    .join(species_turnover_analysis_details)
    .assign(
        zotu_match=reciprocal_hits_filtered.reset_index().groupby('species_id').zotu.apply(', '.join),
        subject_prevalence=(motu_prevalence_by_subject > 0).mean(),
        mean_rabund=lambda x: (x.mean_EEN + x.mean_PostEEN) / 2,
        # Same factor order as the turnover-summary figure so both cells produce
        # bit-identical scores (float multiplication is not associative).
        indicator_score=lambda x: np.abs(x.log2_ratio)
        * x.num_pairs
        * x.mean_rabund
        * x.overall_mean_diss,
        transition_turnover_effect=transition_stats_obs.transition_vs_mean,
        transition_turnover_pvalue=species_transition_test_pvalue,
    )
    .rename(
        columns=dict(
            log2_ratio="species_log2_fold_change",
            pvalue="species_pvalue",
            mean_rabund="species_overall_mean_relative_abundance",
            num_pairs="num_intrasubject_sample_pairs",
            overall_mean_diss="mean_pairwise_braycurtis_dissimilarity",
            zotu_match="zotu_match",
        )
    )
)

d1 = (
    d0.dropna(subset=['indicator_score']).sort_values("indicator_score", ascending=False)
    .assign(
        indicator_score=lambda x: x.indicator_score.round(3),
        species_log2_fold_change=lambda x: x.species_log2_fold_change.round(1),
        # P-values keep 3 significant figures: fixed-decimal rounding would
        # report small p-values as exactly 0.
        species_pvalue=lambda x: x.species_pvalue.map(
            lambda p: float(f"{p:.3g}"), na_action="ignore"
        ),
        species_overall_mean_relative_abundance=lambda x: x.species_overall_mean_relative_abundance.round(4),
        num_intrasubject_sample_pairs=lambda x: x.num_intrasubject_sample_pairs.astype(int),
        mean_pairwise_braycurtis_dissimilarity=lambda x: x.mean_pairwise_braycurtis_dissimilarity.round(2),
        transition_turnover_effect=lambda x: x.transition_turnover_effect.round(2),
        transition_turnover_pvalue=lambda x: x.transition_turnover_pvalue.map(
            lambda p: float(f"{p:.3g}"), na_action="ignore"
        ),
        # zotu_match=lambda x: x.zotu_match.fillna(""),
        uhgg_taxonomy=motu_taxonomy.s__,
        uhgg_url=lambda x: "https://www.ebi.ac.uk/metagenomics/genomes/MGYG0000" + x.index.to_series().str[1:],
        # Left blank; filled in manually.
        ezbc_id="",
        network_association="",
    )
    [
        [
            "indicator_score",
            "species_log2_fold_change",
            "species_pvalue",
            "species_overall_mean_relative_abundance",
            "num_intrasubject_sample_pairs",
            "mean_pairwise_braycurtis_dissimilarity",
            "transition_turnover_effect",
            "transition_turnover_pvalue",
            "zotu_match",
            "uhgg_taxonomy",
            "uhgg_url",
            "ezbc_id",
            "network_association",
        ]
    ]
)

# Table S4 lists species whose (rounded) indicator score exceeds 0.1.
s4_table = d1[lambda x: x.indicator_score > 0.1]
s4_table.to_csv("fig/een_supplementary_table_s4.tsv", sep="\t")
s4_table

In [None]:
# Among Table S4 species (score > 0.1): how many lack a zOTU match?
d1.loc[d1.indicator_score > 0.1, "zotu_match"].isna().value_counts()

In [None]:
# Summary rows for the four focal species highlighted in the turnover figure.
d1.loc[["101493", "101386", "102506", "101346"], :]

In [None]:
# Table S4 species with significantly elevated turnover across the EEN -> PostEEN transition.
d1.query(
    "indicator_score > 0.1"
    " and transition_turnover_pvalue < 0.05"
    " and transition_turnover_effect > 0"
)