# Preamble

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Run the notebook from the project root so the relative meta/ and data/ paths
# used below resolve; the final expression echoes the resolved directory.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
_os.path.realpath(_os.curdir)

## Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Return *labels* rearranged into the leaf order of a dendrogram.

    *linkage* is a SciPy linkage matrix; *labels* must support NumPy-style
    fancy indexing by a list of integer leaf ids (e.g. a NumPy array).
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    # Pre-order traversal visits leaves left to right; collect their ids.
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]


def is_prime(n):
    """Return True if *n* is prime, by trial division up to sqrt(n)."""
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True


def iterate_primes_up_to(n, return_index=False):
    """Yield the primes strictly below ``ceil(n)``.

    The bound is exclusive: ``iterate_primes_up_to(7)`` yields 2, 3, 5, while
    ``iterate_primes_up_to(7.5)`` also yields 7.

    If *return_index* is True, yield ``(k, p)`` pairs instead, where *k* is the
    0-based count of primes yielded before *p*.
    """
    n = int(np.ceil(n))
    idx = 0
    for i in range(n):
        if is_prime(i):
            if return_index:
                yield (idx, i)
            else:
                yield i
            idx += 1


def maximally_shuffled_order(sorted_order):
    """Return the labels of *sorted_order* in a deterministic shuffled order.

    Each position ``p`` is keyed by its residues ``(p % 2, p % 3, p % 5, ...)``
    over the primes below ``ceil(sqrt(n))``. The rows are then rearranged by
    the inverse of the permutation that sorts those keys. Presumably this is
    meant to spread originally-adjacent items apart; the intent is inferred
    from the name. With ``n <= 4`` there are no primes and the order is
    returned unchanged. Empty input returns an empty list.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    # One column per prime holding each position's residue modulo that prime.
    for prime in primes_list:
        table[prime] = table.original_order % prime
    # Positions sorted by residue (mod 2 first, then mod 3, ...).
    residue_order = table.sort_values(primes_list).original_order.values
    table = table.assign(new_order=residue_order)
    return table.sort_values("new_order").index.to_list()

In [None]:
def _label_experiment_sample(x):
    if x.sample_type == "human":
        label = f"[{x.name}] {x.collection_date_relative_een_end} {x.diet_or_media}"
    elif x.sample_type in ["Fermenter_inoculum"]:
        label = f"[{x.name}] {x.source_samples} inoc {x.diet_or_media}"
    elif x.sample_type in ["Fermenter"]:
        # label = f"[{x.name}] {x.source_samples} frmnt {x.diet_or_media}"
        label = f"[{x.name}] {x.source_samples} {x.diet_or_media}"
    elif x.sample_type in ["mouse"]:
        if x.status_mouse_inflamed == 'Inflamed':
            # label = f"[{x.name}] {x.source_samples} 🐭 {x.mouse_genotype} {x.diet_or_media} inflam"
            label = f"[{x.name}] {x.source_samples} {x.diet_or_media} inflam"
        elif x.status_mouse_inflamed == 'not_Inflamed':
            # label = f"[{x.name}] {x.source_samples} 🐭 {x.mouse_genotype} {x.diet_or_media} not_inf"
            label = f"[{x.name}] {x.source_samples} {x.diet_or_media} not_inf"
        else:
            raise ValueError(f"sample type {x.status_mouse_inflamed} not understood")
    else:
        raise ValueError(f"sample type {x.sample_type} not understood")
    return label

In [None]:
# Seaborn's larger "talk" font context, rendered at a low inline DPI.
sns.set_context(context="talk")
plt.rcParams.update({"figure.dpi": 50})

# Prepare Metadata

In [None]:
# Colours for the three sample-pair types.
pair_type_palette = dict(Transition="plum", EEN="pink", PostEEN="lightblue")

# Colours per diet / media condition.
diet_palette = dict(
    EEN="lightgreen",
    PostEEN="lightblue",
    InVitro="plum",
    PreEEN="lightpink",
)

# Display order of subjects (H is deliberately listed right after A and B).
subject_order = list("ABHCDEFGKLMNOPQRSTU")

# NOTE: construct_ordered_palette is given exactly 20 keys (presumably one per
# tab20 colour), so the subject list is padded with "dummy<i>" placeholders.
_palette_keys = subject_order + [f"dummy{i}" for i in range(20 - len(subject_order))]
subject_palette = lib.plot.construct_ordered_palette(_palette_keys, cm="tab20")
subject_palette["X"] = "black"

pair_type_order = ["EEN", "Transition", "PostEEN"]
pair_type_marker_palette = dict(EEN="s", Transition=">", PostEEN="o")
pair_type_linestyle_palette = dict(EEN=":", Transition="-.", PostEEN="-")

In [None]:
# Sample metadata indexed by sample_id. Adds two label columns:
#   label      -- (collection date, diet/media, sample_id) tuple
#   full_label -- readable string built by _label_experiment_sample
sample = pd.read_table("meta/een-mgen/sample.tsv")
sample = sample.assign(
    label=sample[
        ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
    ].apply(tuple, axis=1)
)
sample = sample.set_index("sample_id")
sample = sample.assign(full_label=sample.apply(_label_experiment_sample, axis=1))

# Subject-level metadata indexed by subject_id.
subject = pd.read_table("meta/een-mgen/subject.tsv", index_col="subject_id")

In [None]:
# Display the generated per-sample labels.
sample["full_label"]

# Prepare Data

In [None]:
rotu_counts = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")
# Split off the taxonomy annotation column, then orient counts as samples x zOTUs.
rotu_taxonomy = rotu_counts.pop("taxonomy")
rotu_counts = rotu_counts.T
# Per-sample relative abundance: each row divided by that sample's total count.
rotu_rabund = rotu_counts.div(rotu_counts.sum(axis=1), axis="index")

# Average-linkage clustering of samples on Bray-Curtis dissimilarity.
sample_rotu_bc_linkage = sp.cluster.hierarchy.linkage(
    rotu_rabund, metric="braycurtis", method="average", optimal_ordering=True
)

# Species Enrichment Analysis

In [None]:
def enrichment_test(d):
    """Summarize a paired EEN vs. PostEEN comparison.

    Parameters
    ----------
    d : pandas.DataFrame
        Must contain columns "EEN" and "PostEEN"; each row is one paired
        observation (one subject in the enrichment table of this notebook).

    Returns
    -------
    pandas.Series
        With entries:

        - log2_ratio: mean over rows of log2(PostEEN / EEN).
        - mean_EEN, mean_PostEEN: the column means.
        - pvalue: from scipy.stats.wilcoxon on the pairs, or NaN when wilcoxon
          raises ValueError.
    """
    een = d["EEN"]
    post_een = d["PostEEN"]
    try:
        pvalue = sp.stats.wilcoxon(een, post_een)[1]
    except ValueError:
        # Degenerate inputs are rejected by wilcoxon; record a missing p-value.
        pvalue = np.nan
    summary = {
        "log2_ratio": np.log2(post_een / een).mean(),
        "mean_EEN": een.mean(),
        "mean_PostEEN": post_een.mean(),
        "pvalue": pvalue,
    }
    return pd.Series(
        summary, index=["log2_ratio", "mean_EEN", "mean_PostEEN", "pvalue"]
    )

In [None]:
# Paired EEN vs. PostEEN test for every rOTU, on per-subject mean relative
# abundances. Each rOTU first gets a pseudocount equal to its smallest
# non-zero relative abundance, so that log2 ratios stay finite for rOTUs
# with any non-zero value.
d = (
    rotu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample[["subject_id", "diet_or_media"]])
    .groupby(["subject_id", "diet_or_media"])
    .mean()
    .stack()
    .unstack("diet_or_media")[["EEN", "PostEEN"]]
    .dropna()
    .rename_axis(index=["subject_id", "rotu_id"])
    .groupby(level="rotu_id")
    .apply(enrichment_test)
)
# Benjamini-Hochberg q-values across all tested rOTUs; many rOTUs are tested,
# so raw p-values alone overstate significance. rOTUs whose test could not be
# computed keep NaN in both columns.
d["qvalue"] = np.nan
_tested = d["pvalue"].notna()
if _tested.any():
    d.loc[_tested, "qvalue"] = fdrcorrection(d.loc[_tested, "pvalue"])[1]
d.sort_values("mean_PostEEN", ascending=False).head(20)

# Strain Time-series

## E. coli (Zotu4 / 102506)

In [None]:
# Focal species for this section: E. coli mOTU and its matching 16S zOTU.
motu_id, rotu_id = "102506", "Zotu4"
# Strain-culling threshold passed to drop_low_abundance_strains, and the
# linear threshold of the symlog abundance axes.
drop_strains_thresh, ylinthresh = 0.5, 1e-4

In [None]:
# Load the StrainFacts fit for `motu_id` and align its sample IDs with the
# metadata's "CF_<number>" naming.
sf_fit = sf.data.World.load(
    f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
sf_fit = sf_fit.rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
# Swap the CF_11 and CF_15 labels.
# NOTE(review): presumably corrects a known sample mix-up -- confirm.
sf_fit = sf_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
sf_fit = sf_fit.drop_low_abundance_strains(0.01).rename_coords(strain=str)

# Fit diagnostics and threshold flags.
# NOTE(review): element [1] of each evaluation result is presumably the
# per-sample series -- confirm against sfacts.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error, high_entrp_error, high_comm_entrp = (
    mgtp_error >= 0.1,
    entrp_error >= 0.2,
    comm_entrp >= 1.5,
)

# Colour strains along the genotype-similarity dendrogram; the "other" strain
# ("-1") is left out of the ordering.
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, sf_fit.strain.values))
strain_order.remove("-1")
strain_palette = lib.plot.construct_ordered_palette(strain_order, cm="rainbow")

sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
# Per-subject timeline for the focal species: rotu_rabund[rotu_id] in each
# human sample (black line, symlog y-axis) against signed-sqrt days relative to
# the end of EEN, with StrainFacts strain fractions as stacked bars on a twin
# axis. One panel row per subject in `subject_list`.

# Signed square-root transform of the day axis: keeps the sign (before/after
# EEN end) while compressing large offsets.
trnsfm_x = lambda x: np.sign(x) * np.sqrt(np.abs(x))
bar_width = 1.0  # bar width in transformed-x units

# NOTE(review): `subject_order` is a notebook global that other cells may
# rebind; check which list is in effect when this cell runs.
subject_list = subject_order  # [:3]

fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample of this subject, of any sample_type (used for strain culling).
    all_sample_list = (
        sample[lambda x: x.subject_id == subject_id].index.to_series().pipe(list)
    )
    # Human samples of this subject in collection-date order (the plotted points).
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject: keep the strains surviving
    # drop_low_abundance_strains(drop_strains_thresh) over this subject's
    # samples that are present in the StrainFacts fit.

    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        # No samples of this subject in the fit: plot abundance only.
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Retained strains in genotype-similarity order, then "-1" (the "other"
        # strain).
        # NOTE(review): "-1" is removed from strain_order before strain_palette
        # is built, so strain_palette["-1"] below relies on
        # lib.plot.construct_ordered_palette supplying a colour for it -- confirm.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    # One row per human sample: date, transformed date, rOTU abundance, plus
    # strain-fraction columns joined on sample ID (presumably NaN for samples
    # absent from the fit).
    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            trnsfm_collection_date_relative_een_end=lambda x: x.collection_date_relative_een_end.apply(
                trnsfm_x
            ),
        )
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot: each strain sits on top of the previous ones;
    # top_last accumulates the running per-sample total.
    ax1 = ax.twinx()
    top_last = 0
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].trnsfm_collection_date_relative_een_end,
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN (x = 0 is the end; dates are relative to EEN end).
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks placed in transformed space but labelled with raw day offsets.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Per-subject mean relative abundance of the focal rOTU on and after EEN.
# A pseudocount (this rOTU's smallest non-zero relative abundance) is added
# first, as in the other enrichment cells. Without it, a subject with a zero
# mean gets an infinite log2 ratio, which poisons the mean, makes the
# histogram's automatic range non-finite, and breaks the colour scaling of the
# paired-line plot. The Wilcoxon test is unaffected because adding a constant
# does not change paired differences.
d = (
    rotu_rabund[[rotu_id]]
    .apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    .dropna()
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
)

print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
plt.hist(d.log2_ratio, bins=20)

In [None]:
# Colour scale: log2 ratios mapped linearly from [-max|ratio|, +max|ratio|]
# onto [0, 1] for the diverging coolwarm colormap.
max_c_value = np.abs(d.log2_ratio).max()

fig, ax = plt.subplots(figsize=(3, 8))

# One line per subject joining its EEN and PostEEN mean abundance.
_shaded = d.assign(c=lambda x: ((x.log2_ratio / max_c_value) + 1) / 2)
for _subject, (een_value, post_value, _ratio, shade) in _shaded.iterrows():
    ax.plot([0, 1], [een_value, post_value], c=mpl.cm.coolwarm(shade), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)
ax.set_ylim(1e-5, 1.0)

In [None]:
# Grid overview for the donor subjects: one row per subject and one column per
# sample type. Each panel shows the focal rOTU's relative abundance per sample
# (black points) over StrainFacts strain fractions (stacked bars, twin axis).
# The grid uses its own names for row/column order. Rebinding the notebook-wide
# `subject_order` here would silently restrict the later per-subject timeline
# cells (which read `subject_order`) to these three subjects.
sample_type_order = ["human", "Fermenter", "mouse"]
grid_subject_order = ["A", "B", "H"]

d0 = (
    sample.loc[
        lambda x: (
            x.sample_type.isin(sample_type_order)
            & x.subject_id.isin(grid_subject_order)
        ),
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
            "full_label",
        ],
    ]
    .sort_values(
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
        ]
    )
    .assign(
        rotu_rabund=rotu_rabund[rotu_id],
    )
)

# Sample counts per (subject, sample type), rows/columns in grid order, so the
# axes rows line up with grid_subject_order in the loop below.
_grid_sample_counts = (
    d0[["subject_id", "sample_type"]]
    .value_counts()
    .unstack()
    .reindex(index=grid_subject_order, columns=sample_type_order)
)
fig, axs = plt.subplots(
    *_grid_sample_counts.shape,
    figsize=(90, 15),
    # Column widths proportional to the largest sample count of each type.
    width_ratios=_grid_sample_counts.max().values,
    sharey=True,
    squeeze=False,
    gridspec_kw=dict(wspace=0.1, hspace=3),
)

for subject_id, ax_row in zip(grid_subject_order, axs):
    for sample_type, ax in zip(sample_type_order, ax_row):
        d1 = d0[
            lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)
        ].assign(xpos=lambda x: np.arange(len(x.index)))
        ax.scatter("xpos", "rotu_rabund", data=d1, color="k", s=10, label="__nolegend__")
        ax.set_ylim(-1e-5, 1)
        ax.set_yscale("symlog", linthresh=1e-4, linscale=0.1)
        ax.set_xlim(-0.5, _grid_sample_counts[sample_type].max())
        ax.set_xticks(d1.xpos)
        ax.set_xticklabels(d1.full_label)

        # Strains are culled per panel, over the panel's samples in the fit.
        strain_frac_sample_list = list(set(d1.index) & set(sf_fit.sample.values))
        if len(strain_frac_sample_list) == 0:
            print(f"No strain analysis for {subject_id}.")
            comm = []
            _strain_order = []
        else:
            w = (
                sf_fit.sel(sample=strain_frac_sample_list)
                .drop_low_abundance_strains(drop_strains_thresh)
                .rename_coords(strain=str)
            )
            _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
            comm = w.community.to_pandas()
        d2 = d1.join(comm)
        # Stacked barplot of strain fractions on a twin axis; top_last holds the
        # running per-sample total.
        ax1 = ax.twinx()
        top_last = 0
        for strain in _strain_order:
            ax1.bar(
                x="xpos",
                height=strain,
                data=d2,
                bottom=top_last,
                width=bar_width,
                alpha=1.0,
                color=strain_palette[strain],
                edgecolor="k",
                lw=1,
                label="__nolegend__",
            )
            top_last += d2[strain]
            # Empty proxy artist so the strain shows up in ax's legend.
            ax.scatter([], [], color=strain_palette[strain], label=strain, marker="s", s=80)
        ax1.set_yticks([])
        # Put strains behind points:
        ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
        ax.patch.set_visible(False)  # hide the 'canvas'
        ax1.patch.set_visible(True)  # show the 'canvas'

        ax1.set_ylim(0, 1)
        lib.plot.rotate_xticklabels(ax=ax)
        if sample_type == "human":
            ax.legend(loc="upper right")

In [None]:
# Display the per-sample-type column widths used for the grid figure.
_grid_sample_counts.max().to_numpy()

In [None]:



# One figure per (donor subject, derived sample type). Each shows the focal
# rOTU's relative abundance (black points, symlog y) over StrainFacts strain
# fractions (stacked bars on a twin axis), for samples in the StrainFacts fit.
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    # Plotting order: fit-covered samples sorted by experimental design fields.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.to_series()
        .pipe(list)
    )

    # Cull low abundance strains for this subject / sample type.
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        # Nothing to draw: skip rather than create a zero-width figure whose
        # x-limits collapse to (-0.5, -0.5).
        print(f"No strain analysis for {subject_id}.")
        continue
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
                "full_label",
            ],
        ]
        .assign(rotu_rabund=rotu_rabund[rotu_id])
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # Plotted against the implicit 0..n-1 index, which equals d0.xpos.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Stacked barplot of strain fractions on a twin axis; top_last holds the
    # running per-sample total.
    ax1 = ax.twinx()
    top_last = 0
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)

    ax.set_title(subject_id)
    # Fix one tick per sample before labelling. Calling set_xticklabels on its
    # own attaches the labels to whatever automatic ticks matplotlib chose.
    ax.set_xticks(d0.xpos)
    ax.set_xticklabels(d0.full_label)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Inspect the mouse samples from donors A, B and H that are covered by the
# StrainFacts fit, together with the focal rOTU abundance.
_mouse_columns = [
    "subject_id",
    "collection_date_relative_een_end",
    "sample_type",
    "diet_or_media",
    "mouse_genotype",
    "source_samples",
    "status_mouse_inflamed",
]
_in_fit_mouse = (
    sample.index.isin(sf_fit.sample.values)
    & sample.sample_type.isin(["mouse"])
    & sample.subject_id.isin(["A", "B", "H"])
)
(
    sample.loc[_in_fit_mouse, _mouse_columns]
    .sort_values(_mouse_columns)
    .assign(rotu_rabund=rotu_rabund[rotu_id])
    .assign(xpos=lambda x: np.arange(len(x.index)))
)

## E. lenta (Zotu172 / 102544)

In [None]:
# Focal species for this section: E. lenta mOTU and its matching 16S zOTU.
motu_id, rotu_id = "102544", "Zotu172"
# Strain-culling threshold passed to drop_low_abundance_strains, and the
# linear threshold of the symlog abundance axes.
drop_strains_thresh, ylinthresh = 0.5, 1e-4

In [None]:
# Load the StrainFacts fit for `motu_id` and align its sample IDs with the
# metadata's "CF_<number>" naming.
sf_fit = sf.data.World.load(
    f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
sf_fit = sf_fit.rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
# Swap the CF_11 and CF_15 labels.
# NOTE(review): presumably corrects a known sample mix-up -- confirm.
sf_fit = sf_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
sf_fit = sf_fit.drop_low_abundance_strains(0.01).rename_coords(strain=str)

# Fit diagnostics and threshold flags.
# NOTE(review): element [1] of each evaluation result is presumably the
# per-sample series -- confirm against sfacts.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error, high_entrp_error, high_comm_entrp = (
    mgtp_error >= 0.1,
    entrp_error >= 0.2,
    comm_entrp >= 1.5,
)

# Colour strains along the genotype-similarity dendrogram; the "other" strain
# ("-1") is left out of the ordering.
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, sf_fit.strain.values))
strain_order.remove("-1")
strain_palette = lib.plot.construct_ordered_palette(strain_order, cm="rainbow")

sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
# Per-subject timeline for the focal species: rotu_rabund[rotu_id] in each
# human sample (black line, symlog y-axis) against signed-sqrt days relative to
# the end of EEN, with StrainFacts strain fractions as stacked bars on a twin
# axis. One panel row per subject in `subject_list`.

# Signed square-root transform of the day axis: keeps the sign (before/after
# EEN end) while compressing large offsets.
trnsfm_x = lambda x: np.sign(x) * np.sqrt(np.abs(x))
bar_width = 1.0  # bar width in transformed-x units

# NOTE(review): `subject_order` is a notebook global that other cells may
# rebind; check which list is in effect when this cell runs.
subject_list = subject_order  # [:3]

fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample of this subject, of any sample_type (used for strain culling).
    all_sample_list = (
        sample[lambda x: x.subject_id == subject_id].index.to_series().pipe(list)
    )
    # Human samples of this subject in collection-date order (the plotted points).
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject: keep the strains surviving
    # drop_low_abundance_strains(drop_strains_thresh) over this subject's
    # samples that are present in the StrainFacts fit.

    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        # No samples of this subject in the fit: plot abundance only.
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Retained strains in genotype-similarity order, then "-1" (the "other"
        # strain).
        # NOTE(review): "-1" is removed from strain_order before strain_palette
        # is built, so strain_palette["-1"] below relies on
        # lib.plot.construct_ordered_palette supplying a colour for it -- confirm.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    # One row per human sample: date, transformed date, rOTU abundance, plus
    # strain-fraction columns joined on sample ID (presumably NaN for samples
    # absent from the fit).
    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            trnsfm_collection_date_relative_een_end=lambda x: x.collection_date_relative_een_end.apply(
                trnsfm_x
            ),
        )
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot: each strain sits on top of the previous ones;
    # top_last accumulates the running per-sample total.
    ax1 = ax.twinx()
    top_last = 0
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].trnsfm_collection_date_relative_een_end,
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN (x = 0 is the end; dates are relative to EEN end).
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks placed in transformed space but labelled with raw day offsets.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Compact version of the per-subject timelines for a hand-picked set of
# subjects: two panels per row, no strain legends.

# Signed square-root transform of the day axis.
trnsfm_x = lambda x: np.sign(x) * np.sqrt(np.abs(x))
bar_width = 1.0

subject_list = ['A', 'B', 'H', 'N', 'M', 'S']

# Ceiling division so that an odd number of subjects still gets one panel each;
# floor division would silently drop the last subject through zip() below.
n_panel_cols = 2
n_panel_rows = (len(subject_list) + n_panel_cols - 1) // n_panel_cols
fig, axs = plt.subplots(
    nrows=n_panel_rows,
    ncols=n_panel_cols,
    figsize=(10 * n_panel_cols, 4 * n_panel_rows),
    squeeze=False,
    sharex=True,
    sharey=True,
)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    all_sample_list = (
        sample[lambda x: x.subject_id == subject_id].index.to_series().pipe(list)
    )
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.to_series()
        .pipe(list)
    )

    # Cull low abundance strains for each subject (over all of its samples in
    # the StrainFacts fit, of any sample type).
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            trnsfm_collection_date_relative_een_end=lambda x: x.collection_date_relative_een_end.apply(
                trnsfm_x
            ),
        )
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Stacked barplot of strain fractions on a twin axis; top_last holds the
    # running per-sample total.
    ax1 = ax.twinx()
    top_last = 0
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].trnsfm_collection_date_relative_een_end,
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN (x = 0 is the end; dates are relative to EEN end).
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    # Legends are omitted in this compact layout.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

# Hide the panel left over when the number of subjects is odd.
for _unused_ax in axs.flatten()[len(subject_list):]:
    _unused_ax.set_visible(False)

In [None]:
# Per-subject mean relative abundance of the focal rOTU on and after EEN,
# after adding a pseudocount equal to this rOTU's smallest non-zero value.
_focal_rabund = rotu_rabund[[rotu_id]].apply(
    lambda x: x + x.replace({0: np.inf}).min()
)
_group_means = (
    _focal_rabund.join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()
)
d = _group_means[["EEN", "PostEEN"]].dropna().assign(
    log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"])
)

# Distribution of per-subject log2(PostEEN / EEN) plus a paired Wilcoxon test.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# Paired lines per subject, coloured by the log2 ratio rescaled from
# [-max|ratio|, +max|ratio|] onto [0, 1] for the coolwarm colormap.
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = np.abs(d.log2_ratio).max()
_shaded = d.assign(c=lambda x: ((x.log2_ratio / max_c_value) + 1) / 2)
for _subject, (een_value, post_value, _ratio, shade) in _shaded.iterrows():
    ax.plot([0, 1], [een_value, post_value], c=mpl.cm.coolwarm(shade), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# One figure per (donor subject, derived sample type). Each shows the focal
# rOTU's relative abundance (black points, symlog y) over StrainFacts strain
# fractions (stacked bars on a twin axis), for samples in the StrainFacts fit.
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    # Plotting order: fit-covered samples sorted by experimental design fields.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.to_series()
        .pipe(list)
    )

    # Cull low abundance strains for this subject / sample type.
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        # Nothing to draw: skip rather than create a zero-width figure whose
        # x-limits collapse to (-0.5, -0.5).
        print(f"No strain analysis for {subject_id}.")
        continue
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
                "full_label",
            ],
        ]
        .assign(rotu_rabund=rotu_rabund[rotu_id])
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # Plotted against the implicit 0..n-1 index, which equals d0.xpos.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Stacked barplot of strain fractions on a twin axis; top_last holds the
    # running per-sample total.
    ax1 = ax.twinx()
    top_last = 0
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    ax.set_title(subject_id)
    # One labelled tick per sample (set ticks first so labels cannot land on
    # matplotlib's automatic tick positions).
    ax.set_xticks(d0.xpos)
    ax.set_xticklabels(d0.full_label)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    # Pin the bar axis to [0, 1] like the abundance axis, so stacked fractions
    # are not auto-scaled.
    ax1.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## F. plautii (Zotu49 / 100099)

In [None]:
# Focal species for this section: F. plautii mOTU and its matching 16S zOTU.
motu_id, rotu_id = "100099", "Zotu49"
# Strain-culling threshold passed to drop_low_abundance_strains, and the
# linear threshold of the symlog abundance axes.
drop_strains_thresh, ylinthresh = 0.5, 1e-4

In [None]:
# Load the StrainFacts fit for `motu_id` and align its sample IDs with the
# metadata's "CF_<number>" naming.
sf_fit = sf.data.World.load(
    f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
sf_fit = sf_fit.rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
# Swap the CF_11 and CF_15 labels.
# NOTE(review): presumably corrects a known sample mix-up -- confirm.
sf_fit = sf_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
sf_fit = sf_fit.drop_low_abundance_strains(0.01).rename_coords(strain=str)

# Fit diagnostics and threshold flags.
# NOTE(review): element [1] of each evaluation result is presumably the
# per-sample series -- confirm against sfacts.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error, high_entrp_error, high_comm_entrp = (
    mgtp_error >= 0.1,
    entrp_error >= 0.2,
    comm_entrp >= 1.5,
)

# Colour strains along the genotype-similarity dendrogram; the "other" strain
# ("-1") is left out of the ordering.
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, sf_fit.strain.values))
strain_order.remove("-1")
strain_palette = lib.plot.construct_ordered_palette(strain_order, cm="rainbow")

sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
def trnsfm_x(x):
    """Signed square-root transform applied to the time axis.

    Works element-wise on scalars, numpy arrays and pandas Series, and keeps
    the sign of ``x`` so values before (negative) and after (positive) the
    reference point stay on their own side of zero.
    """
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0  # Strain bar width, in transformed-x units.

subject_list = subject_order  # [:3]

# One panel per subject: relative abundance of `rotu_id` over time (points,
# symlog left axis) drawn in front of stacked strain fractions from `sf_fit`
# (bars, on a twin axis with hidden ticks).
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
# Hoisted: the set of fitted sample names does not change between subjects.
fit_samples = set(sf_fit.sample.values)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample from this subject, of any sample type.
    all_sample_list = sample[lambda x: x.subject_id == subject_id].index.tolist()
    # Human samples only, in collection order; these are the ones drawn.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.tolist()
    )

    # Cull low abundance strains for each subject. Culling considers all of
    # the subject's fitted samples, not only the human ones drawn here.
    # A list comprehension keeps a deterministic order (a set intersection
    # would depend on string hashing).
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Keep the global genotype-similarity order; "-1" (other) goes last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            # trnsfm_x is vectorised, so no element-wise .apply is needed.
            trnsfm_collection_date_relative_een_end=lambda x: trnsfm_x(
                x.collection_date_relative_een_end
            ),
        )
        # Left join: human samples without a strain fit get NaN fractions.
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.trnsfm_collection_date_relative_een_end,
            height=strain_frac,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    if _strain_order:
        # Only add a legend when bars were drawn; otherwise matplotlib warns
        # "No artists with labels found to put in legend".
        ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks are placed in transformed coordinates but labelled with the
    # untransformed values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)  # Strain fractions drawn on a fixed 0-1 scale.
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Per-subject mean relative abundance of `rotu_id` during EEN vs. PostEEN.
# A pseudocount equal to the smallest non-zero abundance of the OTU is added
# before averaging, so the log2 ratio stays finite when a mean would be 0.
# The pseudocount is computed per column, so processing only the `rotu_id`
# column gives the same values as processing the whole table.
d = (
    rotu_rabund[[rotu_id]]
    .apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    .dropna()  # Keep subjects with samples from both periods.
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
)

# Distribution of per-subject log2(PostEEN / EEN), plus a paired Wilcoxon
# signed-rank test of PostEEN vs. EEN means.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# Slope chart: one line per subject from its EEN to its PostEEN mean,
# coloured by log2 ratio on a diverging colormap centred on 0.
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = np.abs(d.log2_ratio).max()
if not max_c_value > 0:
    # e.g. every ratio is 0: avoid 0/0 (NaN -> the colormap's "bad" colour)
    # and draw the lines at the colormap centre instead.
    max_c_value = 1.0
for subject_id, row in d.assign(
    c=lambda x: ((x.log2_ratio / max_c_value) + 1) / 2
).iterrows():
    # Access columns by name rather than relying on their positional order.
    ax.plot([0, 1], [row["EEN"], row["PostEEN"]], c=mpl.cm.coolwarm(row["c"]), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# Strain composition of the Fermenter and mouse samples associated with
# subjects A, B and H: one figure per (subject, sample type).
# Hoisted: the set of fitted sample names is the same for every group.
fit_samples = set(sf_fit.sample.values)
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = sample[
        lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)
    ].index.tolist()
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.tolist()
    )

    # Samples of this group that have strain fractions (deterministic order).
    # With none there is nothing to draw, so report it and skip the group
    # instead of creating an empty, zero-width figure.
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id} ({sample_type}).")
        continue

    # Cull strains that are low abundance across this group of samples.
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Keep the global genotype-similarity order; "-1" (other) goes last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ],
        ]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
        )
        .join(comm)
        # Samples are categorical here, so each gets an integer x position.
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(d0), 4),
    )

    ax.plot(
        "xpos",
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.xpos,
            height=strain_frac,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    # Each subject gets one figure per sample type, so name both.
    ax.set_title(f"{subject_id} ({sample_type})")

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    # Same fixed 0-1 strain-fraction scale as the time-series panels.
    ax1.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(d0) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Bacteroides dorei (Zotu1 / 102478)

In [None]:
# Species analysed in the cells below: `motu_id` selects the sfacts fit
# directory (data/group/een/species/sp-<motu_id>/) and `rotu_id` selects the
# matching column of `rotu_rabund`.
motu_id, rotu_id = "102478", "Zotu1"
# Threshold passed to `drop_low_abundance_strains` when culling strains per
# sample group, and the linear threshold of the symlog abundance axis.
drop_strains_thresh, ylinthresh = 0.5, 1e-4

In [None]:
# Load the fitted sfacts model ("world") for this species, then harmonise its
# sample and strain labels with the rest of the notebook.
sf_fit = (
    sf.data.World.load(
        f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
    )
    # Rename samples to "CF_<n>", where n is the second "_"-separated field
    # parsed as an int (so any zero padding is dropped).
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    # Swap the labels of CF_11 and CF_15.
    # NOTE(review): presumably corrects a known sample-label mix-up; confirm.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    # Strain labels become strings (e.g. the "-1" label used below).
    .rename_coords(strain=str)
)
# Fit diagnostics. Element [1] of each evaluation result is kept (presumably
# the per-sample values -- TODO confirm against the sfacts API), and each is
# flagged against a fixed threshold.
# NOTE(review): the plotting cells for this species do not read these.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
# order strains by walking the genotype linkage tree (see `linkage_order`),
# then colour them along that order with the "rainbow" colormap.
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(
    linkage_order(
        strain_linkage,
        sf_fit.strain.values,
    )
)
# list.remove raises ValueError if the fit has no "-1" strain.
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
)

# Overview of the fitted community, with rows ordered by the same genotype
# linkage and labelled with the palette colours.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
def trnsfm_x(x):
    """Signed square-root transform applied to the time axis.

    Works element-wise on scalars, numpy arrays and pandas Series, and keeps
    the sign of ``x`` so values before (negative) and after (positive) the
    reference point stay on their own side of zero.
    """
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0  # Strain bar width, in transformed-x units.

subject_list = subject_order  # [:3]

# One panel per subject: relative abundance of `rotu_id` over time (points,
# symlog left axis) drawn in front of stacked strain fractions from `sf_fit`
# (bars, on a twin axis with hidden ticks).
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
# Hoisted: the set of fitted sample names does not change between subjects.
fit_samples = set(sf_fit.sample.values)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample from this subject, of any sample type.
    all_sample_list = sample[lambda x: x.subject_id == subject_id].index.tolist()
    # Human samples only, in collection order; these are the ones drawn.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.tolist()
    )

    # Cull low abundance strains for each subject. Culling considers all of
    # the subject's fitted samples, not only the human ones drawn here.
    # A list comprehension keeps a deterministic order (a set intersection
    # would depend on string hashing).
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Keep the global genotype-similarity order; "-1" (other) goes last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            # trnsfm_x is vectorised, so no element-wise .apply is needed.
            trnsfm_collection_date_relative_een_end=lambda x: trnsfm_x(
                x.collection_date_relative_een_end
            ),
        )
        # Left join: human samples without a strain fit get NaN fractions.
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.trnsfm_collection_date_relative_een_end,
            height=strain_frac,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    if _strain_order:
        # Only add a legend when bars were drawn; otherwise matplotlib warns
        # "No artists with labels found to put in legend".
        ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks are placed in transformed coordinates but labelled with the
    # untransformed values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)  # Strain fractions drawn on a fixed 0-1 scale.
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Per-subject mean relative abundance of `rotu_id` during EEN vs. PostEEN.
# A pseudocount equal to the smallest non-zero abundance of the OTU is added
# before averaging, so the log2 ratio stays finite when a mean would be 0.
# The pseudocount is computed per column, so processing only the `rotu_id`
# column gives the same values as processing the whole table.
d = (
    rotu_rabund[[rotu_id]]
    .apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    .dropna()  # Keep subjects with samples from both periods.
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
)

# Distribution of per-subject log2(PostEEN / EEN), plus a paired Wilcoxon
# signed-rank test of PostEEN vs. EEN means.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# Slope chart: one line per subject from its EEN to its PostEEN mean,
# coloured by log2 ratio on a diverging colormap centred on 0.
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = np.abs(d.log2_ratio).max()
if not max_c_value > 0:
    # e.g. every ratio is 0: avoid 0/0 (NaN -> the colormap's "bad" colour)
    # and draw the lines at the colormap centre instead.
    max_c_value = 1.0
for subject_id, row in d.assign(
    c=lambda x: ((x.log2_ratio / max_c_value) + 1) / 2
).iterrows():
    # Access columns by name rather than relying on their positional order.
    ax.plot([0, 1], [row["EEN"], row["PostEEN"]], c=mpl.cm.coolwarm(row["c"]), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# Strain composition of the Fermenter and mouse samples associated with
# subjects A, B and H: one figure per (subject, sample type).
# Hoisted: the set of fitted sample names is the same for every group.
fit_samples = set(sf_fit.sample.values)
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = sample[
        lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)
    ].index.tolist()
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.tolist()
    )

    # Samples of this group that have strain fractions (deterministic order).
    # With none there is nothing to draw, so report it and skip the group
    # instead of creating an empty, zero-width figure.
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id} ({sample_type}).")
        continue

    # Cull strains that are low abundance across this group of samples.
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Keep the global genotype-similarity order; "-1" (other) goes last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ],
        ]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
        )
        .join(comm)
        # Samples are categorical here, so each gets an integer x position.
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(d0), 4),
    )

    ax.plot(
        "xpos",
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.xpos,
            height=strain_frac,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    # Each subject gets one figure per sample type, so name both.
    ax.set_title(f"{subject_id} ({sample_type})")

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    # Same fixed 0-1 strain-fraction scale as the time-series panels.
    ax1.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(d0) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Bacteroides uniformis (Zotu6 / 101346)

In [None]:
# Species analysed in the cells below: `motu_id` selects the sfacts fit
# directory (data/group/een/species/sp-<motu_id>/) and `rotu_id` selects the
# matching column of `rotu_rabund`.
motu_id, rotu_id = "101346", "Zotu6"
# Threshold passed to `drop_low_abundance_strains` when culling strains per
# sample group, and the linear threshold of the symlog abundance axis.
drop_strains_thresh, ylinthresh = 0.5, 1e-4

In [None]:
# Load the fitted sfacts model ("world") for this species, then harmonise its
# sample and strain labels with the rest of the notebook.
sf_fit = (
    sf.data.World.load(
        f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
    )
    # Rename samples to "CF_<n>", where n is the second "_"-separated field
    # parsed as an int (so any zero padding is dropped).
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    # Swap the labels of CF_11 and CF_15.
    # NOTE(review): presumably corrects a known sample-label mix-up; confirm.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    # Strain labels become strings (e.g. the "-1" label used below).
    .rename_coords(strain=str)
)
# Fit diagnostics. Element [1] of each evaluation result is kept (presumably
# the per-sample values -- TODO confirm against the sfacts API), and each is
# flagged against a fixed threshold.
# NOTE(review): the plotting cells for this species do not read these.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
# order strains by walking the genotype linkage tree (see `linkage_order`),
# then colour them along that order with the "rainbow" colormap.
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(
    linkage_order(
        strain_linkage,
        sf_fit.strain.values,
    )
)
# list.remove raises ValueError if the fit has no "-1" strain.
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
)

# Overview of the fitted community, with rows ordered by the same genotype
# linkage and labelled with the palette colours.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
def trnsfm_x(x):
    """Signed square-root transform applied to the time axis.

    Works element-wise on scalars, numpy arrays and pandas Series, and keeps
    the sign of ``x`` so values before (negative) and after (positive) the
    reference point stay on their own side of zero.
    """
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0  # Strain bar width, in transformed-x units.

subject_list = subject_order  # [:3]

# One panel per subject: relative abundance of `rotu_id` over time (points,
# symlog left axis) drawn in front of stacked strain fractions from `sf_fit`
# (bars, on a twin axis with hidden ticks).
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
# Hoisted: the set of fitted sample names does not change between subjects.
fit_samples = set(sf_fit.sample.values)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample from this subject, of any sample type.
    all_sample_list = sample[lambda x: x.subject_id == subject_id].index.tolist()
    # Human samples only, in collection order; these are the ones drawn.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.tolist()
    )

    # Cull low abundance strains for each subject. Culling considers all of
    # the subject's fitted samples, not only the human ones drawn here.
    # A list comprehension keeps a deterministic order (a set intersection
    # would depend on string hashing).
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Keep the global genotype-similarity order; "-1" (other) goes last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            # trnsfm_x is vectorised, so no element-wise .apply is needed.
            trnsfm_collection_date_relative_een_end=lambda x: trnsfm_x(
                x.collection_date_relative_een_end
            ),
        )
        # Left join: human samples without a strain fit get NaN fractions.
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.trnsfm_collection_date_relative_een_end,
            height=strain_frac,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    if _strain_order:
        # Only add a legend when bars were drawn; otherwise matplotlib warns
        # "No artists with labels found to put in legend".
        ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks are placed in transformed coordinates but labelled with the
    # untransformed values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)  # Strain fractions drawn on a fixed 0-1 scale.
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
def trnsfm_x(x):
    """Signed square-root transform applied to the time axis.

    Works element-wise on scalars, numpy arrays and pandas Series, and keeps
    the sign of ``x`` so values before (negative) and after (positive) the
    reference point stay on their own side of zero.
    """
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0  # Strain bar width, in transformed-x units.

# Same figure as the previous cell, restricted to the first three subjects.
subject_list = subject_order[:3]

# One panel per subject: relative abundance of `rotu_id` over time (points,
# symlog left axis) drawn in front of stacked strain fractions from `sf_fit`
# (bars, on a twin axis with hidden ticks).
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
# Hoisted: the set of fitted sample names does not change between subjects.
fit_samples = set(sf_fit.sample.values)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample from this subject, of any sample type.
    all_sample_list = sample[lambda x: x.subject_id == subject_id].index.tolist()
    # Human samples only, in collection order; these are the ones drawn.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.tolist()
    )

    # Cull low abundance strains for each subject. Culling considers all of
    # the subject's fitted samples, not only the human ones drawn here.
    # A list comprehension keeps a deterministic order (a set intersection
    # would depend on string hashing).
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Keep the global genotype-similarity order; "-1" (other) goes last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            # trnsfm_x is vectorised, so no element-wise .apply is needed.
            trnsfm_collection_date_relative_een_end=lambda x: trnsfm_x(
                x.collection_date_relative_een_end
            ),
        )
        # Left join: human samples without a strain fit get NaN fractions.
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.trnsfm_collection_date_relative_een_end,
            height=strain_frac,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    if _strain_order:
        # Only add a legend when bars were drawn; otherwise matplotlib warns
        # "No artists with labels found to put in legend".
        ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks are placed in transformed coordinates but labelled with the
    # untransformed values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)  # Strain fractions drawn on a fixed 0-1 scale.
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Per-subject mean relative abundance of `rotu_id` during EEN vs. PostEEN.
# A pseudocount equal to the smallest non-zero abundance of the OTU is added
# before averaging, so the log2 ratio stays finite when a mean would be 0.
# The pseudocount is computed per column, so processing only the `rotu_id`
# column gives the same values as processing the whole table.
d = (
    rotu_rabund[[rotu_id]]
    .apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    .dropna()  # Keep subjects with samples from both periods.
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
)

# Distribution of per-subject log2(PostEEN / EEN), plus a paired Wilcoxon
# signed-rank test of PostEEN vs. EEN means.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# Slope chart: one line per subject from its EEN to its PostEEN mean,
# coloured by log2 ratio on a diverging colormap centred on 0.
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = np.abs(d.log2_ratio).max()
if not max_c_value > 0:
    # e.g. every ratio is 0: avoid 0/0 (NaN -> the colormap's "bad" colour)
    # and draw the lines at the colormap centre instead.
    max_c_value = 1.0
for subject_id, row in d.assign(
    c=lambda x: ((x.log2_ratio / max_c_value) + 1) / 2
).iterrows():
    # Access columns by name rather than relying on their positional order.
    ax.plot([0, 1], [row["EEN"], row["PostEEN"]], c=mpl.cm.coolwarm(row["c"]), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# Strain composition of the Fermenter and mouse samples associated with
# subjects A, B and H: one figure per (subject, sample type).
# Hoisted: the set of fitted sample names is the same for every group.
fit_samples = set(sf_fit.sample.values)
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = sample[
        lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)
    ].index.tolist()
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.tolist()
    )

    # Samples of this group that have strain fractions (deterministic order).
    # With none there is nothing to draw, so report it and skip the group
    # instead of creating an empty, zero-width figure.
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id} ({sample_type}).")
        continue

    # Cull strains that are low abundance across this group of samples.
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Keep the global genotype-similarity order; "-1" (other) goes last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ],
        ]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
        )
        .join(comm)
        # Samples are categorical here, so each gets an integer x position.
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(d0), 4),
    )

    ax.plot(
        "xpos",
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.xpos,
            height=strain_frac,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    # Each subject gets one figure per sample type, so name both.
    ax.set_title(f"{subject_id} ({sample_type})")

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    # Same fixed 0-1 strain-fraction scale as the time-series panels.
    ax1.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(d0) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Bacteroides fragilis (Zotu12 / 101337)

In [None]:
# Species analysed in the cells below: `motu_id` selects the sfacts fit
# directory (data/group/een/species/sp-<motu_id>/) and `rotu_id` selects the
# matching column of `rotu_rabund`.
motu_id, rotu_id = "101337", "Zotu12"
# Threshold passed to `drop_low_abundance_strains` when culling strains per
# sample group, and the linear threshold of the symlog abundance axis.
drop_strains_thresh, ylinthresh = 0.5, 1e-4

In [None]:
# Load the fitted sfacts model ("world") for this species, then harmonise its
# sample and strain labels with the rest of the notebook.
sf_fit = (
    sf.data.World.load(
        f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
    )
    # Rename samples to "CF_<n>", where n is the second "_"-separated field
    # parsed as an int (so any zero padding is dropped).
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    # Swap the labels of CF_11 and CF_15.
    # NOTE(review): presumably corrects a known sample-label mix-up; confirm.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    # Strain labels become strings (e.g. the "-1" label used below).
    .rename_coords(strain=str)
)
# Fit diagnostics. Element [1] of each evaluation result is kept (presumably
# the per-sample values -- TODO confirm against the sfacts API), and each is
# flagged against a fixed threshold.
# NOTE(review): the plotting cells for this species do not read these.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
# order strains by walking the genotype linkage tree (see `linkage_order`),
# then colour them along that order with the "rainbow" colormap.
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(
    linkage_order(
        strain_linkage,
        sf_fit.strain.values,
    )
)
# list.remove raises ValueError if the fit has no "-1" strain.
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
)

# Overview of the fitted community, with rows ordered by the same genotype
# linkage and labelled with the palette colours.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
def trnsfm_x(x):
    """Signed square-root transform applied to the time axis.

    Works element-wise on scalars, numpy arrays and pandas Series, and keeps
    the sign of ``x`` so values before (negative) and after (positive) the
    reference point stay on their own side of zero.
    """
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0  # Strain bar width, in transformed-x units.

subject_list = subject_order  # [:3]

# One panel per subject: relative abundance of `rotu_id` over time (points,
# symlog left axis) drawn in front of stacked strain fractions from `sf_fit`
# (bars, on a twin axis with hidden ticks).
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
# Hoisted: the set of fitted sample names does not change between subjects.
fit_samples = set(sf_fit.sample.values)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample from this subject, of any sample type.
    all_sample_list = sample[lambda x: x.subject_id == subject_id].index.tolist()
    # Human samples only, in collection order; these are the ones drawn.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.tolist()
    )

    # Cull low abundance strains for each subject. Culling considers all of
    # the subject's fitted samples, not only the human ones drawn here.
    # A list comprehension keeps a deterministic order (a set intersection
    # would depend on string hashing).
    strain_frac_sample_list = [s for s in all_sample_list if s in fit_samples]
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Keep the global genotype-similarity order; "-1" (other) goes last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            # trnsfm_x is vectorised, so no element-wise .apply is needed.
            trnsfm_collection_date_relative_een_end=lambda x: trnsfm_x(
                x.collection_date_relative_een_end
            ),
        )
        # Left join: human samples without a strain fit get NaN fractions.
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot, one strain (column of d0) at a time.
    # NOTE(review): assumes `strain_palette` supplies a colour for "-1", which
    # was removed from `strain_order` before the palette was built -- confirm.
    ax1 = ax.twinx()
    top_last = 0
    for strain in _strain_order:
        strain_frac = d0[strain]
        ax1.bar(
            x=d0.trnsfm_collection_date_relative_een_end,
            height=strain_frac,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last = top_last + strain_frac
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    if _strain_order:
        # Only add a legend when bars were drawn; otherwise matplotlib warns
        # "No artists with labels found to put in legend".
        ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks are placed in transformed coordinates but labelled with the
    # untransformed values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)  # Strain fractions drawn on a fixed 0-1 scale.
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Compare the focal rOTU's mean relative abundance per subject between the
# EEN and PostEEN diet phases.
#
# Before averaging, each column of rotu_rabund gets a pseudo-count equal to
# its smallest non-zero value, so the log2 ratio below stays finite for
# subjects where the rOTU is undetected in one phase.
d = (
    rotu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    # Keep only subjects that were sampled in both phases.
    .dropna()
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
)

# Summary: mean log2 fold-change, paired Wilcoxon test, and a histogram.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# Slope chart: one line per subject from its EEN to its PostEEN abundance,
# colored on a symmetric coolwarm scale where 0.5 means "no change".
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = np.abs(d.log2_ratio).max()
if not max_c_value > 0:
    # Every ratio is 0 (or there are no subjects). Dividing by it would yield
    # NaN color positions, which the colormap renders as its "bad" color.
    max_c_value = 1.0
color_position = ((d.log2_ratio / max_c_value) + 1) / 2
# Access columns by name rather than unpacking row values positionally.
for subject_id, een_rabund, post_rabund, c in zip(
    d.index, d["EEN"], d["PostEEN"], color_position
):
    ax.plot([0, 1], [een_rabund, post_rabund], c=mpl.cm.coolwarm(c), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# For each subject and each non-human sample type (Fermenter / mouse):
# - black points show the focal rOTU's relative abundance in each sample;
# - stacked bars on a twin y-axis show the sfacts strain composition of the
#   same samples.
_sample_meta_columns = [
    "sample_type",
    "diet_or_media",
    "mouse_genotype",
    "source_samples",
    "status_mouse_inflamed",
]
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    # Samples that also have a strain fit, sorted by their metadata so that
    # related conditions sit next to each other on the x-axis.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(_sample_meta_columns)
        .index.to_series()
        .pipe(list)
    )

    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        # No sample of this subject/type has a strain fit, so
        # subject_sample_order is empty as well and there is nothing to draw.
        # Skip instead of creating a zero-width, empty figure.
        print(f"No strain analysis for {subject_id} ({sample_type}).")
        continue

    # Cull low abundance strains for each subject
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Global strain order restricted to the strains kept here. "-1" is
    # appended as an extra column (presumably the pooled dropped strains --
    # TODO confirm against drop_low_abundance_strains).
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, _sample_meta_columns]
        .assign(rotu_rabund=rotu_rabund[rotu_id])
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )
    n_samples = len(d0)

    fig, ax = plt.subplots(figsize=(0.4 * n_samples, 4))

    # Given only a column name, matplotlib takes x from the Series' index
    # (the sample IDs). These become categorical positions 0..n-1, which line
    # up with xpos, and they label the ticks with sample IDs.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot: each strain's fraction sits on top of the running
    # cumulative total of the strains drawn before it.
    ax1 = ax.twinx()
    top_last = 0
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)
    ax.set_title(subject_id)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, n_samples - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)