# Preamble

In [None]:
%load_ext autoreload

In [None]:
# Switch the working directory to the project root so that the relative
# paths used throughout the notebook (meta/, ref/, data/) resolve.
# Raises KeyError if the PROJECT_ROOT environment variable is not set.
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
# Last expression of the cell: shows the resolved working directory.
_os.path.realpath(_os.path.curdir)

## Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as mtick

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Return `labels` reordered to follow the leaves of a clustering tree.

    `linkage` is a SciPy linkage matrix. The tree built from it is walked in
    pre-order and the resulting leaf ids are used to index `labels`, so
    `labels` must support indexing by a list of integers (e.g. a NumPy array).
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]


def is_prime(n):
    """Return True if `n` is a prime number, False otherwise.

    Uses trial division by every integer from 2 up to int(sqrt(n)).
    Values of `n` less than or equal to 1 are never prime.
    """
    if n <= 1:
        return False
    limit = int(n**0.5)
    return all(n % divisor != 0 for divisor in range(2, limit + 1))


def iterate_primes_up_to(n, return_index=False):
    """Yield, in increasing order, the primes strictly below ceil(n).

    `n` may be a float; it is rounded up to an integer first. Note that an
    integer `n` is itself excluded (the candidates are `range(ceil(n))`).

    With `return_index=True` each item is a `(rank, prime)` tuple, where
    `rank` counts the primes yielded so far starting from 0.
    """
    upper = int(np.ceil(n))
    primes = (candidate for candidate in range(upper) if is_prime(candidate))
    for rank, prime in enumerate(primes):
        yield (rank, prime) if return_index else prime


def maximally_shuffled_order(sorted_order):
    """Return the items of `sorted_order` in a deterministically shuffled order.

    Each item's original position is reduced modulo every prime below
    sqrt(len(sorted_order)); the positions are then sorted by those residues
    (smallest prime first) and the resulting permutation is used to reorder
    the items, so that neighbours in the input end up spread apart.

    Parameters
    ----------
    sorted_order : sequence
        Items in their original (sorted) order; used as the index of an
        internal DataFrame.

    Returns
    -------
    list
        The same items in the shuffled order. An empty input gives `[]`.
    """
    n = len(sorted_order)
    if n == 0:
        # Nothing to shuffle.
        return []

    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    # One residue column per prime; the column label is the prime itself.
    for prime in primes_list:
        table[prime] = table.original_order % prime
    # new_order[i] is the original position of the item that lands at rank i
    # when the table is sorted by the residue columns.
    table = table.assign(new_order=table.sort_values(primes_list).original_order.values)
    return table.sort_values("new_order").index.to_list()

In [None]:
def _label_experiment_sample(x):
    """Build a human-readable label for one sample record.

    `x` is a row-like object (e.g. a row of the sample table passed by
    `DataFrame.apply(..., axis=1)`) exposing `name` and `sample_type`, plus
    the attributes needed for its sample type:

    - "human": `collection_date_relative_een_end`, `diet_or_media`
    - "Fermenter_inoculum" / "Fermenter": `source_samples`, `diet_or_media`
    - "mouse": `source_samples`, `diet_or_media`, `status_mouse_inflamed`

    Raises
    ------
    ValueError
        If `sample_type` is not one of the types above, or if a mouse sample
        has a `status_mouse_inflamed` other than "Inflamed"/"not_Inflamed".
    """
    if x.sample_type == "human":
        return f"[{x.name}] {x.collection_date_relative_een_end} {x.diet_or_media}"
    if x.sample_type == "Fermenter_inoculum":
        return f"[{x.name}] {x.source_samples} inoc {x.diet_or_media}"
    if x.sample_type == "Fermenter":
        return f"[{x.name}] {x.source_samples} {x.diet_or_media}"
    if x.sample_type == "mouse":
        if x.status_mouse_inflamed == "Inflamed":
            return f"[{x.name}] {x.source_samples} {x.diet_or_media} inflam"
        if x.status_mouse_inflamed == "not_Inflamed":
            return f"[{x.name}] {x.source_samples} {x.diet_or_media} not_inf"
        # The offending value here is the inflammation status, not the sample type.
        raise ValueError(
            f"mouse inflammation status {x.status_mouse_inflamed} not understood"
        )
    raise ValueError(f"sample type {x.sample_type} not understood")

In [None]:
def plot_stacked_barplot(data, x_var, order, palette=None, ax=None, **kwargs):
    """Draw a stacked bar chart and return the axes it was drawn on.

    Parameters
    ----------
    data : table indexable by column name (e.g. a DataFrame)
        Holds the x positions and one column of heights per segment.
    x_var : str
        Column of `data` giving the bar positions; also used for the x ticks.
    order : sequence of str
        Height columns, stacked bottom-to-top in this order.
    palette : mapping, optional
        Segment name -> color. Built with
        `lib.plot.construct_ordered_palette(order)` when omitted.
    ax : matplotlib Axes, optional
        Target axes; `plt.subplot()` is used when omitted.
    **kwargs
        Extra keyword arguments for `ax.bar`, overriding the default style.
    """
    ax = plt.subplot() if ax is None else ax
    if palette is None:
        palette = lib.plot.construct_ordered_palette(order)

    # Default bar appearance; caller-supplied keyword arguments win.
    bar_style = {"width": 1.0, "alpha": 1.0, "edgecolor": "k", "lw": 1, **kwargs}

    x_values = data[x_var]
    running_top = 0  # Height already occupied by the segments drawn so far.
    for segment in order:
        heights = data[segment]
        ax.bar(
            x=x_values,
            height=heights,
            bottom=running_top,
            label=segment,
            color=palette[segment],
            **bar_style,
        )
        running_top = running_top + heights
    ax.set_xticks(x_values.values)
    return ax


# Quick visual check of plot_stacked_barplot on a toy table with two
# segments (y1 and y2) at three x positions.
plot_stacked_barplot(
    pd.DataFrame(dict(t=[0, 1, 2], y1=[0.0, 0.5, 1.0], y2=[1.0, 0.5, 0.0])),
    x_var="t",
    order=["y1", "y2"],
)

In [None]:
# Notebook-wide plotting defaults: seaborn "talk" context and a figure
# resolution of 50 dpi.
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 50

# Prepare Metadata

In [None]:
# Colors keyed by sample-pair type and by diet / growth medium.
pair_type_palette = dict(Transition="plum", EEN="pink", PostEEN="lightblue")

diet_palette = dict(
    EEN="lightgreen",
    PostEEN="lightblue",
    InVitro="plum",
    PreEEN="lightpink",
)

# Display order of the subjects (note "H" directly after "B").
subject_order = list("ABHCDEFGKLMNOPQRSTU")

# NOTE: The palette is requested for exactly 20 items, so the subject list is
# padded with dummy entries up to that length.
_n_padding = 20 - len(subject_order)
_padded_subjects = subject_order + [f"dummy{i}" for i in range(_n_padding)]
subject_palette = lib.plot.construct_ordered_palette(_padded_subjects, cm="tab20")
subject_palette["X"] = "black"

# Order, marker and line style for each sample-pair type.
pair_type_order = ["EEN", "Transition", "PostEEN"]
pair_type_marker_palette = dict(EEN="s", Transition=">", PostEEN="o")
pair_type_linestyle_palette = dict(EEN=":", Transition="-.", PostEEN="-")

In [None]:
# Sample metadata, indexed by sample_id, with two derived label columns:
#   label      - tuple of (relative collection date, diet/media, sample_id)
#   full_label - display string built by _label_experiment_sample
_sample_raw = pd.read_table("meta/een-mgen/sample.tsv")
_label_columns = ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
_sample_raw["label"] = _sample_raw[_label_columns].apply(tuple, axis=1)
sample = _sample_raw.set_index("sample_id")
sample["full_label"] = sample.apply(_label_experiment_sample, axis=1)

# Subject metadata, indexed by subject_id.
subject = pd.read_table("meta/een-mgen/subject.tsv", index_col="subject_id")

In [None]:
# Display the generated per-sample labels for a visual check.
sample.full_label

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-separated lineage string into a Series of seven ranks.

    The fields are labelled, in order, "d__", "p__", "c__", "o__", "f__",
    "g__" and "s__". A string that does not have exactly seven fields makes
    the Series construction fail with a length-mismatch error.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
motu_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Keep only the genomes that are their own species representative.
_uhgg_genomes = pd.read_table(motu_taxonomy_inpath)
_is_species_rep = _uhgg_genomes.Genome == _uhgg_genomes.Species_rep
_motu_taxonomy = _uhgg_genomes[_is_species_rep]

# species_id is "1" followed by the third '-'-separated field of the
# MGnify accession.
_motu_taxonomy = _motu_taxonomy.assign(
    species_id="1" + _motu_taxonomy.MGnify_accession.str.split("-").str[2]
).set_index("species_id")

# One column per taxonomic rank, one row per species.
motu_taxonomy = _motu_taxonomy.Lineage.apply(parse_taxonomy_string)
motu_taxonomy

# Prepare Data

In [None]:
# Species depth per sample: read the long (sample, species_id, depth) table
# and pivot it to a samples x species matrix, filling missing pairs with 0.
motu_depth = (
    pd.read_table(
        "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
        names=["sample", "species_id", "depth"],
        index_col=["sample", "species_id"],
    )
    .depth.unstack(fill_value=0)
    # Species ids become strings; sample names are normalised to "CF_<int>"
    # using the integer after the first underscore.
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    # Exchange the labels of CF_11 and CF_15 (both renames happen at once).
    .rename({"CF_15": "CF_11", "CF_11": "CF_15"})  # Sample swap
)
# Relative abundance: each sample's depths divided by that sample's total.
motu_rabund = motu_depth.divide(motu_depth.sum(1), axis=0)

motu_rabund

In [None]:
# zOTU count table as read: rows are zOTUs, columns are samples plus a
# trailing "taxonomy" column.
rotu_counts = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")
# Split off the taxonomy annotation, then transpose to samples x zOTUs.
rotu_taxonomy = rotu_counts.taxonomy
rotu_counts = rotu_counts.drop(columns=["taxonomy"]).T
# Relative abundance: each sample's counts divided by that sample's total.
rotu_rabund = rotu_counts.divide(rotu_counts.sum(1), axis=0)

# Average-linkage clustering of samples on Bray-Curtis dissimilarity.
sample_rotu_bc_linkage = sp.cluster.hierarchy.linkage(
    rotu_rabund, method="average", metric="braycurtis", optimal_ordering=True
)

# Species Enrichment Analysis

In [None]:
def enrichment_test(d):
    """Summarise the EEN vs. PostEEN difference for one group of paired values.

    `d` must provide the paired columns "EEN" and "PostEEN". Returns a Series
    with:

    - "log2_ratio": mean of log2(PostEEN / EEN)
    - "mean_EEN", "mean_PostEEN": the column means
    - "pvalue": p-value of a Wilcoxon signed-rank test on the pairs, or NaN
      if the test raises ValueError
    """
    een = d["EEN"]
    post_een = d["PostEEN"]
    try:
        pvalue = sp.stats.wilcoxon(een, post_een)[1]
    except ValueError:
        pvalue = np.nan
    summary = {
        "log2_ratio": np.log2(post_een / een).mean(),
        "mean_EEN": een.mean(),
        "mean_PostEEN": post_een.mean(),
        "pvalue": pvalue,
    }
    return pd.Series(list(summary.values()), index=list(summary))

In [None]:
# Per-zOTU enrichment between EEN and PostEEN, paired by subject.
d = (
    # Pseudocount: add each zOTU column's smallest non-zero relative
    # abundance to every sample (zeros are swapped for inf before taking the
    # minimum, so an all-zero column would get inf added).
    rotu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample[["subject_id", "diet_or_media"]])
    # Mean abundance per subject and diet, for every zOTU.
    .groupby(["subject_id", "diet_or_media"])
    .mean()
    # Reshape to rows of (subject, zOTU) with one column per diet, keeping
    # only pairs that have both an EEN and a PostEEN value.
    .stack()
    .unstack("diet_or_media")[["EEN", "PostEEN"]]
    .dropna()
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
    .rename_axis(index=["subject_id", "rotu_id"])
    # One enrichment_test summary row per zOTU, across subjects.
    .groupby(level="rotu_id")
    .apply(enrichment_test)
)
# Show the 20 zOTUs with the highest mean PostEEN abundance.
d.sort_values("mean_PostEEN", ascending=False).head(20)
# fig, ax = plt.subplots()
# print(d.log2_ratio.mean())
# print(sp.stats.wilcoxon(d['PostEEN'], d['EEN']))
# ax.hist(d.log2_ratio, bins=20)

# Strain Time-series

## E. coli (Zotu4 / 102506)

In [None]:
# Parameters for this species section.
motu_id = "102506"  # species id used in the strain-fit file path
rotu_id = "Zotu4"  # matching zOTU column in rotu_rabund
drop_strains_thresh = 0.5  # passed to drop_low_abundance_strains per sample subset
ylinthresh = 1e-4  # linear threshold of the symlog y-axis

In [None]:
# Load the strain-facts fit for the current species (motu_id) and normalise
# its coordinates to match the sample metadata.
sf_fit = (
    sf.data.World.load(
        f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
    )
    # Sample names become "CF_<int>" using the integer after the first underscore.
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    # Exchange the labels of CF_11 and CF_15 (same swap as applied to motu_depth).
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    # NOTE(review): presumably pools strains below this abundance into the
    # "-1" ("other") strain removed from strain_order below - confirm in sfacts.
    .drop_low_abundance_strains(0.01)
    .rename_coords(strain=str)
)
# Fit diagnostics; element [1] of each returned value is kept.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
# Boolean flags for poorly fit or high-entropy entries (not used further in
# the visible part of this notebook).
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(
    linkage_order(
        strain_linkage,
        sf_fit.strain.values,
    )
)
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
)

# Overview of the fitted community, with rows clustered by the precomputed
# genotype linkage and coloured with the strain palette.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
# Per-subject time series: zOTU relative abundance (black line, symlog axis)
# drawn over a stacked barplot of strain fractions (twin axis).

# Signed square root: compresses large |x| while keeping the sign.
trnsfm_x = lambda x: np.sign(x) * np.sqrt(np.abs(x))
bar_width = 1.0

subject_list = subject_order  # [:3]

# One row of axes per subject, sharing both axes.
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample of this subject, regardless of sample type.
    all_sample_list = (
        sample[lambda x: x.subject_id == subject_id].index.to_series().pipe(list)
    )
    # Only the human samples, in chronological order; these are plotted.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject

    # Samples of this subject that are present in the strain fit.
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        print(f"No strain analysis for {subject_id}.")
        # Empty placeholders: nothing is joined and no bars are drawn below.
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Retained strains in palette order, with the "-1" strain last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    # Plotting table: date, transformed date, zOTU abundance and (where
    # available) one column of strain fractions per strain.
    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            trnsfm_collection_date_relative_een_end=lambda x: x.collection_date_relative_een_end.apply(
                trnsfm_x
            ),
        )
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # Running top of the stack at each x position.
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].trnsfm_collection_date_relative_een_end,
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks are placed at transformed positions but labelled with the
    # untransformed day values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Mean relative abundance of the focal zOTU per subject under EEN and
# PostEEN, keeping only subjects that have both, plus their log2 ratio.
# No pseudocount is added here, so a zero mean gives an infinite or NaN ratio.
d = (
    rotu_rabund.join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    .dropna()
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
)

# Mean log2 ratio, paired Wilcoxon signed-rank test, and the distribution of
# per-subject ratios.
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
plt.hist(d.log2_ratio, bins=20)

In [None]:
# Slope plot: one line per subject from its EEN mean to its PostEEN mean.
# Largest absolute log2 ratio, used to scale the colors symmetrically.
max_c_value = np.abs(d.log2_ratio).max()

fig, ax = plt.subplots(figsize=(3, 8))

# c maps log2_ratio from [-max_c_value, max_c_value] onto [0, 1] for the
# coolwarm colormap. The row is unpacked positionally, so the column order
# must be EEN, PostEEN, log2_ratio, c.
for subject_id, (een_rabund, post_rabund, log2_ratio, c) in d.assign(
    c=lambda x: ((x.log2_ratio / max_c_value) + 1) / 2
).iterrows():
    ax.plot([0, 1], [een_rabund, post_rabund], c=mpl.cm.coolwarm(c), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)
ax.set_ylim(1e-5, 1.0)

In [None]:
# Grid of panels (rows: subjects, columns: sample types) showing the focal
# zOTU's relative abundance as points over a stacked barplot of strain
# fractions, one x position per sample.

sample_type_order = ["human", "Fermenter", "mouse"]
# Deliberately NOT named `subject_order`: that notebook-wide list (all
# subjects) is what the per-subject time-series cells iterate over, and it
# must not be narrowed to three subjects as a side effect of this cell.
_grid_subject_order = ["A", "B", "H"]

# Metadata of the samples to plot, sorted so that samples appear in a
# consistent order within each panel, plus the zOTU relative abundance.
d0 = (
    sample.loc[
        lambda x: (
            x.sample_type.isin(sample_type_order)
            & x.subject_id.isin(_grid_subject_order)
        ),
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
            "full_label",
        ],
    ]
    .sort_values(
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
        ]
    )
    .assign(
        rotu_rabund=rotu_rabund[rotu_id],
    )
)

# Number of samples per (subject, sample type). Reindexing both axes pins the
# table to the same row/column order that the plotting loops below use.
_grid_sample_counts = (
    d0[["subject_id", "sample_type"]]
    .value_counts()
    .unstack()
    .reindex(index=_grid_subject_order, columns=sample_type_order)
)
# Column widths are proportional to the largest sample count of each type.
fig, axs = plt.subplots(
    *_grid_sample_counts.shape,
    figsize=(90, 15),
    width_ratios=_grid_sample_counts.max().values,
    sharey=True,
    gridspec_kw=dict(wspace=0.1, hspace=3),
)

for subject_id, ax_row in zip(_grid_subject_order, axs):
    for sample_type, ax in zip(sample_type_order, ax_row):
        # Samples of this panel, numbered left to right.
        d1 = d0[
            lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)
        ].assign(xpos=lambda x: np.arange(len(x.index)))
        ax.scatter(
            "xpos", "rotu_rabund", data=d1, color="k", s=10, label="__nolegend__"
        )
        # ax.set_aspect(700, adjustable="datalim", anchor="NW")
        ax.set_ylim(-1e-5, 1)
        ax.set_yscale("symlog", linthresh=1e-4, linscale=0.1)
        # Same x extent for every panel in a column, so bars line up.
        ax.set_xlim(-0.5, _grid_sample_counts[sample_type].max())
        ax.set_xticks(d1.xpos)
        ax.set_xticklabels(d1.full_label)

        # Samples of this panel that are present in the strain fit.
        strain_frac_sample_list = list(set(d1.index) & set(sf_fit.sample.values))
        if len(strain_frac_sample_list) == 0:
            print(f"No strain analysis for {subject_id}.")
            # Empty placeholders: nothing is joined and no bars are drawn.
            comm = []
            _strain_order = []
        else:
            w = (
                sf_fit.sel(sample=strain_frac_sample_list)
                .drop_low_abundance_strains(drop_strains_thresh)
                .rename_coords(strain=str)
            )
            # Retained strains in palette order, with the "-1" strain last.
            _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
            comm = w.community.to_pandas()
        d2 = d1.join(comm)
        # Plot stacked barplot
        ax1 = ax.twinx()
        top_last = 0  # Running top of the stack at each x position.
        for strain in _strain_order:
            ax1.bar(
                x="xpos",
                height=strain,
                data=d2,
                bottom=top_last,
                width=bar_width,
                alpha=1.0,
                color=strain_palette[strain],
                edgecolor="k",
                lw=1,
                label="__nolegend__",
            )
            top_last += d2[strain]
            # Empty scatter on the front axes: exists only to give the legend
            # (drawn on `ax`) an entry for this strain.
            ax.scatter(
                [], [], color=strain_palette[strain], label=strain, marker="s", s=80
            )
        ax1.set_yticks([])
        # Put strains behind points:
        ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
        ax.patch.set_visible(False)  # hide the 'canvas'
        ax1.patch.set_visible(True)  # show the 'canvas'

        ax1.set_ylim(0, 1)
        lib.plot.rotate_xticklabels(ax=ax)
        if sample_type == "human":
            ax.legend(loc="upper right")


# fig.tight_layout()

# axs[0, 0].plot([0, 100], [0, 1])

# axs[0, 1].plot([0, 10], [0, 1])
# axs[0, 1].set_aspect(10)

# axs[0, 2].plot([0, 20], [0, 1])
# axs[0, 2].set_aspect(5)

In [None]:
# Largest sample count per sample type (the column width ratios of the grid).
_grid_sample_counts.max().values

In [None]:
# One figure per (subject, derived sample type): zOTU relative abundance as
# points (symlog axis) over a stacked barplot of strain fractions, with one
# x position per sample that is present in the strain fit.
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    # Samples to plot: those with strain data, in a fixed metadata order.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject

    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        # Nothing to draw: the figure width below is proportional to the
        # number of samples, so an empty list would give a zero-width figure.
        continue

    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Retained strains in palette order, with the "-1" strain last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    # Plotting table: metadata, zOTU abundance, strain fractions and the
    # integer x position of each sample.
    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
                "full_label",
            ],
        ]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
        )
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # Points are placed at the same explicit integer positions as the bars.
    ax.plot(
        "xpos",
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # Running top of the stack at each x position.
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)

    ax.set_title(subject_id)
    # Fix one tick per sample before labelling, so that every label is tied
    # to its own sample.
    ax.set_xticks(d0.xpos)
    ax.set_xticklabels(d0.full_label)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Inspection table: metadata and zOTU relative abundance of the mouse samples
# of subjects A, B and H that are present in the strain fit, sorted and
# numbered (xpos) in the order they would be plotted.
(
    sample.loc[
        lambda x: (
            x.index.isin(sf_fit.sample.values)
            & x.sample_type.isin(["mouse"])
            & x.subject_id.isin(["A", "B", "H"])
            # & (x.sample_type == "Fermenter")
        ),
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
        ],
    ]
    .sort_values(
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
        ]
    )
    .assign(
        rotu_rabund=rotu_rabund[rotu_id],
    )
    # .join(comm)
    .assign(xpos=lambda x: np.arange(len(x.index)))
)

## E. lenta (Zotu172 / 102544)

In [None]:
# Parameters for this species section.
motu_id = "102544"  # species id used in the strain-fit file path
rotu_id = "Zotu172"  # matching zOTU column in rotu_rabund
drop_strains_thresh = 0.5  # passed to drop_low_abundance_strains per sample subset
ylinthresh = 1e-4  # linear threshold of the symlog y-axis

In [None]:
# Load the strain-facts fit for the current species (motu_id) and normalise
# its coordinates to match the sample metadata.
sf_fit = (
    sf.data.World.load(
        f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
    )
    # Sample names become "CF_<int>" using the integer after the first underscore.
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    # Exchange the labels of CF_11 and CF_15 (same swap as applied to motu_depth).
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    # NOTE(review): presumably pools strains below this abundance into the
    # "-1" ("other") strain removed from strain_order below - confirm in sfacts.
    .drop_low_abundance_strains(0.01)
    .rename_coords(strain=str)
)
# Fit diagnostics; element [1] of each returned value is kept.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
# Boolean flags for poorly fit or high-entropy entries (not used further in
# the visible part of this notebook).
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(
    linkage_order(
        strain_linkage,
        sf_fit.strain.values,
    )
)
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
)

# Overview of the fitted community, with rows clustered by the precomputed
# genotype linkage and coloured with the strain palette.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
# Per-subject time series: zOTU relative abundance (black line, symlog axis)
# drawn over a stacked barplot of strain fractions (twin axis).

# Signed square root: compresses large |x| while keeping the sign.
trnsfm_x = lambda x: np.sign(x) * np.sqrt(np.abs(x))
bar_width = 1.0

# NOTE(review): this uses whatever `subject_order` currently holds; verify it
# has not been narrowed by a cell that ran earlier in the session.
subject_list = subject_order  # [:3]

# One row of axes per subject, sharing both axes.
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample of this subject, regardless of sample type.
    all_sample_list = (
        sample[lambda x: x.subject_id == subject_id].index.to_series().pipe(list)
    )
    # Only the human samples, in chronological order; these are plotted.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject

    # Samples of this subject that are present in the strain fit.
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        print(f"No strain analysis for {subject_id}.")
        # Empty placeholders: nothing is joined and no bars are drawn below.
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Retained strains in palette order, with the "-1" strain last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    # Plotting table: date, transformed date, zOTU abundance and (where
    # available) one column of strain fractions per strain.
    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            trnsfm_collection_date_relative_een_end=lambda x: x.collection_date_relative_een_end.apply(
                trnsfm_x
            ),
        )
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # Running top of the stack at each x position.
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].trnsfm_collection_date_relative_een_end,
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Ticks are placed at transformed positions but labelled with the
    # untransformed day values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Two-column grid of per-subject time series for a hand-picked set of
# subjects: zOTU relative abundance (black line, symlog axis) over a stacked
# barplot of strain fractions (twin axis). No legend is drawn in this figure.

# Signed square root: compresses large |x| while keeping the sign.
trnsfm_x = lambda x: np.sign(x) * np.sqrt(np.abs(x))
bar_width = 1.0

subject_list = ["A", "B", "H", "N", "M", "S"]

# Ceiling division so that an odd number of subjects still gets a panel for
# the last one (floor division would silently drop it).
_n_cols = 2
_n_rows = -(-len(subject_list) // _n_cols)

fig, axs = plt.subplots(
    nrows=_n_rows,
    ncols=_n_cols,
    figsize=(10 * _n_cols, 4 * _n_rows),
    squeeze=False,
    sharex=True,
    sharey=True,
)
# Hide any panel left over when the subjects do not fill the grid.
for _unused_ax in axs.flatten()[len(subject_list) :]:
    _unused_ax.set_visible(False)

for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample of this subject, regardless of sample type.
    all_sample_list = (
        sample[lambda x: x.subject_id == subject_id].index.to_series().pipe(list)
    )
    # Only the human samples, in chronological order; these are plotted.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject

    # Samples of this subject that are present in the strain fit.
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        print(f"No strain analysis for {subject_id}.")
        # Empty placeholders: nothing is joined and no bars are drawn below.
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Retained strains in palette order, with the "-1" strain last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    # Plotting table: date, transformed date, zOTU abundance and (where
    # available) one column of strain fractions per strain.
    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            trnsfm_collection_date_relative_een_end=lambda x: x.collection_date_relative_een_end.apply(
                trnsfm_x
            ),
        )
        .join(comm)
    )

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # Running top of the stack at each x position.
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].trnsfm_collection_date_relative_een_end,
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    # Ticks are placed at transformed positions but labelled with the
    # untransformed day values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Mean relative abundance of the focal zOTU per subject under EEN and
# PostEEN, keeping only subjects that have both, plus their log2 ratio.
d = (
    # Pseudocount: add each zOTU column's smallest non-zero relative
    # abundance to every sample before averaging.
    rotu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    .dropna()
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
)

# Mean log2 ratio, paired Wilcoxon signed-rank test, and the distribution of
# per-subject ratios.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# Slope plot: one line per subject from its EEN mean to its PostEEN mean.
fig, ax = plt.subplots(figsize=(3, 8))
# Largest absolute log2 ratio, used to scale the colors symmetrically.
max_c_value = np.abs(d.log2_ratio).max()
# c maps log2_ratio from [-max_c_value, max_c_value] onto [0, 1] for the
# coolwarm colormap. The row is unpacked positionally, so the column order
# must be EEN, PostEEN, log2_ratio, c.
for subject_id, (een_rabund, post_rabund, log2_ratio, c) in d.assign(
    c=lambda x: ((x.log2_ratio / max_c_value) + 1) / 2
).iterrows():
    ax.plot([0, 1], [een_rabund, post_rabund], c=mpl.cm.coolwarm(c), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# One figure per (subject, derived sample type): zOTU relative abundance as
# points (symlog axis) over a stacked barplot of strain fractions, with one
# x position per sample that is present in the strain fit.
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    # Samples to plot: those with strain data, in a fixed metadata order.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject

    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        # Nothing to draw: the figure width below is proportional to the
        # number of samples, so an empty list would give a zero-width figure.
        continue

    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Retained strains in palette order, with the "-1" strain last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    # Plotting table: metadata, zOTU abundance, strain fractions and the
    # integer x position of each sample.
    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ],
        ]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
        )
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # NOTE(review): no explicit x is given, so the point positions and tick
    # labels come from how matplotlib treats the Series index - confirm they
    # line up with the integer bar positions used below.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # Running top of the stack at each x position.
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    ax.set_title(subject_id)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    # Pin the strain-fraction axis to [0, 1], as in the other strain
    # barplots, so the stacked bars span the full panel height.
    ax1.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Flavonifractor plautii (Zotu49 / 100099)

In [None]:
# Species analysed in this section: `motu_id` selects the strain-fit file
# (".../sp-<motu_id>/...") and `rotu_id` the matching column of `rotu_rabund`.
motu_id, rotu_id = "100099", "Zotu49"
# `linthresh` of the symlog y-axis used for relative abundance.
ylinthresh = 1e-4
# Threshold passed to `drop_low_abundance_strains` for each subject's samples.
drop_strains_thresh = 0.5

In [None]:
# Load the fitted strain model for the current species and tidy its coordinates.
_world_path = f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
sf_fit = (
    sf.data.World.load(_world_path)
    # Rebuild each sample name as "CF_<int>" from its second "_"-separated field.
    .rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
    # Exchange the labels CF_11 and CF_15.
    # NOTE(review): presumably corrects a known sample swap - confirm.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    .rename_coords(strain=str)
)

# Fit diagnostics (element [1] of each evaluation result) and threshold flags.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, sf_fit.strain.values))
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(strain_order, cm="rainbow")

# Community overview, with rows ordered by the precomputed strain linkage.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
def trnsfm_x(x):
    """Signed square root: keeps the sign of `x` and compresses its magnitude."""
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0

subject_list = subject_order  # all subjects (slice, e.g. [:3], to plot fewer)

# One panel per subject (shared axes): relative abundance of `rotu_id` over time
# (black line with points, symlog left axis) in front of stacked per-strain bars
# from `sf_fit` (twin right axis). The x-axis is the collection date relative to
# the end of EEN, on the signed square-root scale defined by `trnsfm_x`.
fig, axs = plt.subplots(
    len(subject_list),
    1,
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
_een_line_style = dict(lw=1, linestyle="--", color="k")
for subject_id, ax in zip(subject_list, axs.ravel()):
    ax.set_title(subject_id)

    # Every sample of this subject (any sample type) ...
    _from_subject = sample.subject_id == subject_id
    all_sample_list = list(sample[_from_subject].index)
    # ... and its human samples only, in chronological order.
    subject_sample_order = list(
        sample[_from_subject & sample.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index
    )

    # Cull low abundance strains for each subject
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if strain_frac_sample_list:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Global genotype-similarity order, with the "-1" strain stacked last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()
    else:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []

    # Plot table: one row per human sample, in chronological order.
    d0 = sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
    d0 = d0.assign(
        rotu_rabund=rotu_rabund[rotu_id],
        trnsfm_collection_date_relative_een_end=d0[
            "collection_date_relative_een_end"
        ].apply(trnsfm_x),
    ).join(comm)

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        color="k",
        marker="o",
        markersize=5,
        label="__nolegend__",
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # running top of the stack; becomes a Series after the first strain
    for strain, d1 in d0[_strain_order].items():
        ax1.bar(
            x=d0.loc[d1.index, "trnsfm_collection_date_relative_een_end"],
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # End of EEN (x = 0) and start of EEN (relative to its end).
    ax.axvline(0, **_een_line_style)
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        **_een_line_style,
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)

    # Ticks sit at transformed positions but are labelled with the raw values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Mean relative abundance of `rotu_id` per subject during EEN and PostEEN.
# Each column of `rotu_rabund` first gets its smallest non-zero value added as a
# pseudocount, so the log-ratio below stays finite.
_pseudocounted = rotu_rabund.apply(lambda col: col + col.replace({0: np.inf}).min())
_mean_by_diet = (
    _pseudocounted.join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()
)
# Keep only subjects that have both an EEN and a PostEEN mean.
d = _mean_by_diet[["EEN", "PostEEN"]].dropna()
d = d.assign(log2_ratio=np.log2(d["PostEEN"] / d["EEN"]))

# Summary statistics and histogram of the per-subject log2 fold changes.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# One line per subject from the EEN mean to the PostEEN mean, coloured by the
# log2 ratio rescaled to [0, 1] (0.5 corresponds to no change).
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = d.log2_ratio.abs().max()
_with_colour = d.assign(c=(d.log2_ratio / max_c_value + 1) / 2)
for subject_id, een_rabund, post_rabund, log2_ratio, c in _with_colour.itertuples(
    name=None
):
    ax.plot([0, 1], [een_rabund, post_rabund], c=mpl.cm.coolwarm(c), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# Fermenter and mouse samples: one figure per (subject, sample type) pair with
# the relative abundance of `rotu_id` (black points, symlog left axis) drawn in
# front of stacked per-strain bars taken from `sf_fit` (twin right axis).
_metadata_columns = [
    "sample_type",
    "diet_or_media",
    "mouse_genotype",
    "source_samples",
    "status_mouse_inflamed",
]
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        # None of this pair's samples are in the strain fit, so the table built
        # below would have no rows. Skip the pair instead of creating a
        # zero-width figure with an empty legend.
        print(f"No strain analysis for {subject_id}.")
        continue

    # Samples that have a strain fit, ordered by metadata. The row order of
    # `d0` (and therefore `xpos`) follows this ordering.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(_metadata_columns)
        .index.to_series()
        .pipe(list)
    )

    # Cull low abundance strains for each subject
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Keep the global genotype-similarity order; the "-1" ("other") strain is
    # stacked last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, _metadata_columns]
        .assign(rotu_rabund=rotu_rabund[rotu_id])
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # NOTE(review): with a single column name and `data=`, matplotlib appears to
    # take x from the Series index (the sample ids), which as categories sit at
    # positions 0..n-1 in row order and so line up with `xpos` - confirm.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # running top of the stack; becomes a Series after the first strain
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    ax.set_title(subject_id)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Bacteroides dorei (Zotu1 / 102478)

In [None]:
# Species analysed in this section: `motu_id` selects the strain-fit file
# (".../sp-<motu_id>/...") and `rotu_id` the matching column of `rotu_rabund`.
motu_id, rotu_id = "102478", "Zotu1"
# `linthresh` of the symlog y-axis used for relative abundance.
ylinthresh = 1e-4
# Threshold passed to `drop_low_abundance_strains` for each subject's samples.
drop_strains_thresh = 0.5

In [None]:
# Load the fitted strain model for the current species and tidy its coordinates.
_world_path = f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
sf_fit = (
    sf.data.World.load(_world_path)
    # Rebuild each sample name as "CF_<int>" from its second "_"-separated field.
    .rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
    # Exchange the labels CF_11 and CF_15.
    # NOTE(review): presumably corrects a known sample swap - confirm.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    .rename_coords(strain=str)
)

# Fit diagnostics (element [1] of each evaluation result) and threshold flags.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, sf_fit.strain.values))
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(strain_order, cm="rainbow")

# Community overview, with rows ordered by the precomputed strain linkage.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
def trnsfm_x(x):
    """Signed square root: keeps the sign of `x` and compresses its magnitude."""
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0

subject_list = subject_order  # all subjects (slice, e.g. [:3], to plot fewer)

# One panel per subject (shared axes): relative abundance of `rotu_id` over time
# (black line with points, symlog left axis) in front of stacked per-strain bars
# from `sf_fit` (twin right axis). The x-axis is the collection date relative to
# the end of EEN, on the signed square-root scale defined by `trnsfm_x`.
fig, axs = plt.subplots(
    len(subject_list),
    1,
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
_een_line_style = dict(lw=1, linestyle="--", color="k")
for subject_id, ax in zip(subject_list, axs.ravel()):
    ax.set_title(subject_id)

    # Every sample of this subject (any sample type) ...
    _from_subject = sample.subject_id == subject_id
    all_sample_list = list(sample[_from_subject].index)
    # ... and its human samples only, in chronological order.
    subject_sample_order = list(
        sample[_from_subject & sample.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index
    )

    # Cull low abundance strains for each subject
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if strain_frac_sample_list:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Global genotype-similarity order, with the "-1" strain stacked last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()
    else:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []

    # Plot table: one row per human sample, in chronological order.
    d0 = sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
    d0 = d0.assign(
        rotu_rabund=rotu_rabund[rotu_id],
        trnsfm_collection_date_relative_een_end=d0[
            "collection_date_relative_een_end"
        ].apply(trnsfm_x),
    ).join(comm)

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        color="k",
        marker="o",
        markersize=5,
        label="__nolegend__",
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # running top of the stack; becomes a Series after the first strain
    for strain, d1 in d0[_strain_order].items():
        ax1.bar(
            x=d0.loc[d1.index, "trnsfm_collection_date_relative_een_end"],
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # End of EEN (x = 0) and start of EEN (relative to its end).
    ax.axvline(0, **_een_line_style)
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        **_een_line_style,
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)

    # Ticks sit at transformed positions but are labelled with the raw values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Mean relative abundance of `rotu_id` per subject during EEN and PostEEN.
# Each column of `rotu_rabund` first gets its smallest non-zero value added as a
# pseudocount, so the log-ratio below stays finite.
_pseudocounted = rotu_rabund.apply(lambda col: col + col.replace({0: np.inf}).min())
_mean_by_diet = (
    _pseudocounted.join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()
)
# Keep only subjects that have both an EEN and a PostEEN mean.
d = _mean_by_diet[["EEN", "PostEEN"]].dropna()
d = d.assign(log2_ratio=np.log2(d["PostEEN"] / d["EEN"]))

# Summary statistics and histogram of the per-subject log2 fold changes.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# One line per subject from the EEN mean to the PostEEN mean, coloured by the
# log2 ratio rescaled to [0, 1] (0.5 corresponds to no change).
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = d.log2_ratio.abs().max()
_with_colour = d.assign(c=(d.log2_ratio / max_c_value + 1) / 2)
for subject_id, een_rabund, post_rabund, log2_ratio, c in _with_colour.itertuples(
    name=None
):
    ax.plot([0, 1], [een_rabund, post_rabund], c=mpl.cm.coolwarm(c), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# Fermenter and mouse samples: one figure per (subject, sample type) pair with
# the relative abundance of `rotu_id` (black points, symlog left axis) drawn in
# front of stacked per-strain bars taken from `sf_fit` (twin right axis).
_metadata_columns = [
    "sample_type",
    "diet_or_media",
    "mouse_genotype",
    "source_samples",
    "status_mouse_inflamed",
]
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        # None of this pair's samples are in the strain fit, so the table built
        # below would have no rows. Skip the pair instead of creating a
        # zero-width figure with an empty legend.
        print(f"No strain analysis for {subject_id}.")
        continue

    # Samples that have a strain fit, ordered by metadata. The row order of
    # `d0` (and therefore `xpos`) follows this ordering.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(_metadata_columns)
        .index.to_series()
        .pipe(list)
    )

    # Cull low abundance strains for each subject
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Keep the global genotype-similarity order; the "-1" ("other") strain is
    # stacked last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, _metadata_columns]
        .assign(rotu_rabund=rotu_rabund[rotu_id])
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # NOTE(review): with a single column name and `data=`, matplotlib appears to
    # take x from the Series index (the sample ids), which as categories sit at
    # positions 0..n-1 in row order and so line up with `xpos` - confirm.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # running top of the stack; becomes a Series after the first strain
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    ax.set_title(subject_id)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Bacteroides uniformis (Zotu6 / 101346)

In [None]:
# Species analysed in this section: `motu_id` selects the strain-fit file
# (".../sp-<motu_id>/...") and `rotu_id` the matching column of `rotu_rabund`.
motu_id, rotu_id = "101346", "Zotu6"
# `linthresh` of the symlog y-axis used for relative abundance.
ylinthresh = 1e-4
# Threshold passed to `drop_low_abundance_strains` for each subject's samples.
drop_strains_thresh = 0.5

In [None]:
# Load the fitted strain model for the current species and tidy its coordinates.
_world_path = f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
sf_fit = (
    sf.data.World.load(_world_path)
    # Rebuild each sample name as "CF_<int>" from its second "_"-separated field.
    .rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
    # Exchange the labels CF_11 and CF_15.
    # NOTE(review): presumably corrects a known sample swap - confirm.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    .rename_coords(strain=str)
)

# Fit diagnostics (element [1] of each evaluation result) and threshold flags.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, sf_fit.strain.values))
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(strain_order, cm="rainbow")

# Community overview, with rows ordered by the precomputed strain linkage.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
def trnsfm_x(x):
    """Signed square root: keeps the sign of `x` and compresses its magnitude."""
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0

subject_list = subject_order  # all subjects (slice, e.g. [:3], to plot fewer)

# One panel per subject (shared axes): relative abundance of `rotu_id` over time
# (black line with points, symlog left axis) in front of stacked per-strain bars
# from `sf_fit` (twin right axis). The x-axis is the collection date relative to
# the end of EEN, on the signed square-root scale defined by `trnsfm_x`.
fig, axs = plt.subplots(
    len(subject_list),
    1,
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
_een_line_style = dict(lw=1, linestyle="--", color="k")
for subject_id, ax in zip(subject_list, axs.ravel()):
    ax.set_title(subject_id)

    # Every sample of this subject (any sample type) ...
    _from_subject = sample.subject_id == subject_id
    all_sample_list = list(sample[_from_subject].index)
    # ... and its human samples only, in chronological order.
    subject_sample_order = list(
        sample[_from_subject & sample.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index
    )

    # Cull low abundance strains for each subject
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if strain_frac_sample_list:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Global genotype-similarity order, with the "-1" strain stacked last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()
    else:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []

    # Plot table: one row per human sample, in chronological order.
    d0 = sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
    d0 = d0.assign(
        rotu_rabund=rotu_rabund[rotu_id],
        trnsfm_collection_date_relative_een_end=d0[
            "collection_date_relative_een_end"
        ].apply(trnsfm_x),
    ).join(comm)

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        color="k",
        marker="o",
        markersize=5,
        label="__nolegend__",
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # running top of the stack; becomes a Series after the first strain
    for strain, d1 in d0[_strain_order].items():
        ax1.bar(
            x=d0.loc[d1.index, "trnsfm_collection_date_relative_een_end"],
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # End of EEN (x = 0) and start of EEN (relative to its end).
    ax.axvline(0, **_een_line_style)
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        **_een_line_style,
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)

    # Ticks sit at transformed positions but are labelled with the raw values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
def trnsfm_x(x):
    """Signed square root: keeps the sign of `x` and compresses its magnitude."""
    return np.sign(x) * np.sqrt(np.abs(x))


bar_width = 1.0

subject_list = subject_order[:3]  # first three subjects only

# One panel per subject (shared axes): relative abundance of `rotu_id` over time
# (black line with points, symlog left axis) in front of stacked per-strain bars
# from `sf_fit` (twin right axis). The x-axis is the collection date relative to
# the end of EEN, on the signed square-root scale defined by `trnsfm_x`.
fig, axs = plt.subplots(
    len(subject_list),
    1,
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
_een_line_style = dict(lw=1, linestyle="--", color="k")
for subject_id, ax in zip(subject_list, axs.ravel()):
    ax.set_title(subject_id)

    # Every sample of this subject (any sample type) ...
    _from_subject = sample.subject_id == subject_id
    all_sample_list = list(sample[_from_subject].index)
    # ... and its human samples only, in chronological order.
    subject_sample_order = list(
        sample[_from_subject & sample.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index
    )

    # Cull low abundance strains for each subject
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if strain_frac_sample_list:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Global genotype-similarity order, with the "-1" strain stacked last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()
    else:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []

    # Plot table: one row per human sample, in chronological order.
    d0 = sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
    d0 = d0.assign(
        rotu_rabund=rotu_rabund[rotu_id],
        trnsfm_collection_date_relative_een_end=d0[
            "collection_date_relative_een_end"
        ].apply(trnsfm_x),
    ).join(comm)

    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        color="k",
        marker="o",
        markersize=5,
        label="__nolegend__",
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # running top of the stack; becomes a Series after the first strain
    for strain, d1 in d0[_strain_order].items():
        ax1.bar(
            x=d0.loc[d1.index, "trnsfm_collection_date_relative_een_end"],
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # End of EEN (x = 0) and start of EEN (relative to its end).
    ax.axvline(0, **_een_line_style)
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        **_een_line_style,
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)

    # Ticks sit at transformed positions but are labelled with the raw values.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
# Mean relative abundance of `rotu_id` per subject during EEN and PostEEN.
# Each column of `rotu_rabund` first gets its smallest non-zero value added as a
# pseudocount, so the log-ratio below stays finite.
_pseudocounted = rotu_rabund.apply(lambda col: col + col.replace({0: np.inf}).min())
_mean_by_diet = (
    _pseudocounted.join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()
)
# Keep only subjects that have both an EEN and a PostEEN mean.
d = _mean_by_diet[["EEN", "PostEEN"]].dropna()
d = d.assign(log2_ratio=np.log2(d["PostEEN"] / d["EEN"]))

# Summary statistics and histogram of the per-subject log2 fold changes.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# One line per subject from the EEN mean to the PostEEN mean, coloured by the
# log2 ratio rescaled to [0, 1] (0.5 corresponds to no change).
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = d.log2_ratio.abs().max()
_with_colour = d.assign(c=(d.log2_ratio / max_c_value + 1) / 2)
for subject_id, een_rabund, post_rabund, log2_ratio, c in _with_colour.itertuples(
    name=None
):
    ax.plot([0, 1], [een_rabund, post_rabund], c=mpl.cm.coolwarm(c), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# Fermenter and mouse samples: one figure per (subject, sample type) pair with
# the relative abundance of `rotu_id` (black points, symlog left axis) drawn in
# front of stacked per-strain bars taken from `sf_fit` (twin right axis).
_metadata_columns = [
    "sample_type",
    "diet_or_media",
    "mouse_genotype",
    "source_samples",
    "status_mouse_inflamed",
]
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        # None of this pair's samples are in the strain fit, so the table built
        # below would have no rows. Skip the pair instead of creating a
        # zero-width figure with an empty legend.
        print(f"No strain analysis for {subject_id}.")
        continue

    # Samples that have a strain fit, ordered by metadata. The row order of
    # `d0` (and therefore `xpos`) follows this ordering.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(_metadata_columns)
        .index.to_series()
        .pipe(list)
    )

    # Cull low abundance strains for each subject
    w = (
        sf_fit.sel(sample=strain_frac_sample_list)
        .drop_low_abundance_strains(drop_strains_thresh)
        .rename_coords(strain=str)
    )
    # Keep the global genotype-similarity order; the "-1" ("other") strain is
    # stacked last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()

    d0 = (
        sample.loc[subject_sample_order, _metadata_columns]
        .assign(rotu_rabund=rotu_rabund[rotu_id])
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # NOTE(review): with a single column name and `data=`, matplotlib appears to
    # take x from the Series index (the sample ids), which as categories sit at
    # positions 0..n-1 in row order and so line up with `xpos` - confirm.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()
    top_last = 0  # running top of the stack; becomes a Series after the first strain
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    ax.set_title(subject_id)

    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Bacteroides fragilis (Zotu12 / 101337)

In [None]:
# Species analysed in this section: `motu_id` selects the strain-fit file
# (".../sp-<motu_id>/...") and `rotu_id` the matching column of `rotu_rabund`.
motu_id, rotu_id = "101337", "Zotu12"
# `linthresh` of the symlog y-axis used for relative abundance.
ylinthresh = 1e-4
# Threshold passed to `drop_low_abundance_strains` for each subject's samples.
drop_strains_thresh = 0.5

In [None]:
# Load the fitted strain model for the current species and tidy its coordinates.
_world_path = f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
sf_fit = (
    sf.data.World.load(_world_path)
    # Rebuild each sample name as "CF_<int>" from its second "_"-separated field.
    .rename_coords(sample=lambda s: f"CF_{int(s.split('_')[1])}")
    # Exchange the labels CF_11 and CF_15.
    # NOTE(review): presumably corrects a known sample swap - confirm.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    .rename_coords(strain=str)
)

# Fit diagnostics (element [1] of each evaluation result) and threshold flags.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
high_mgtp_error = mgtp_error >= 0.1
high_entrp_error = entrp_error >= 0.2
high_comm_entrp = comm_entrp >= 1.5

# Genotype similarity ordered palette:
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, sf_fit.strain.values))
strain_order.remove("-1")  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(strain_order, cm="rainbow")

# Community overview, with rows ordered by the precomputed strain linkage.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
)

In [None]:
# Signed square-root transform for the time axis: compresses large |x| while
# keeping negative and positive values on opposite sides of x = 0.
trnsfm_x = lambda x: np.sign(x) * np.sqrt(np.abs(x))
bar_width = 1.0  # Width of the strain bars, in transformed x units.

subject_list = subject_order  # [:3]

# One panel per subject, stacked vertically, sharing both axes.
fig, axs = plt.subplots(
    nrows=len(subject_list),
    figsize=(10, 4 * len(subject_list)),
    squeeze=False,
    sharex=True,
    sharey=True,
)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    # Every sample belonging to this subject, whatever its sample type.
    all_sample_list = (
        sample[lambda x: x.subject_id == subject_id].index.to_series().pipe(list)
    )
    # Human samples only, sorted by collection date relative to the end of EEN.
    subject_sample_order = (
        sample[lambda x: (x.subject_id == subject_id) & x.sample_type.isin(["human"])]
        .sort_values(["collection_date_relative_een_end"])
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject

    # Samples of this subject that are also present in the strain fit.
    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        print(f"No strain analysis for {subject_id}.")
        # Empty placeholders so the join and the bar loop below have nothing to add.
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Keep the genotype-linkage order for the retained strains; the "other"
        # strain ("-1") always goes last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    # Per-sample table: collection date, its transformed value, the rOTU relative
    # abundance, and (where available) the strain fractions.
    d0 = (
        sample.loc[subject_sample_order, ["collection_date_relative_een_end"]]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
            trnsfm_collection_date_relative_een_end=lambda x: x.collection_date_relative_een_end.apply(
                trnsfm_x
            ),
        )
        .join(comm)
    )

    # Relative abundance of the rOTU over (transformed) time.
    ax.plot(
        "trnsfm_collection_date_relative_een_end",
        "rotu_rabund",
        data=d0,
        marker="o",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()  # Strain fractions live on a secondary y-axis.
    top_last = 0  # Running top of the stack; becomes a per-sample Series after the first add.
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].trnsfm_collection_date_relative_een_end,
            height=d1,
            bottom=top_last,
            width=bar_width,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # Start and end of EEN
    # NOTE(review): `subject` must be the subject-metadata table here; make sure
    # no other cell has rebound that name before this cell is (re-)run.
    ax.axvline(0, lw=1, linestyle="--", color="k")
    ax.axvline(
        trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
        lw=1,
        linestyle="--",
        color="k",
    )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=1)
    # Tick positions are given in original units and placed at their transformed locations.
    xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    ax.set_xticks(trnsfm_x(xtick_pos))
    ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax1.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(ax=ax)

In [None]:
def _add_pseudocount(col):
    """Add the column's smallest non-zero value (zeros are masked with +inf for the min)."""
    return col + col.replace({0: np.inf}).min()


# Mean pseudocounted relative abundance of the focal rOTU for each subject and
# diet, keeping only subjects that have both an EEN and a PostEEN value.
d = (
    rotu_rabund.apply(_add_pseudocount)
    .join(sample)
    .groupby(["subject_id", "diet_or_media"])[rotu_id]
    .mean()
    .unstack()[["EEN", "PostEEN"]]
    .dropna()
)
d = d.assign(log2_ratio=np.log2(d["PostEEN"] / d["EEN"]))

# Histogram of the per-subject log2 fold changes, with a paired Wilcoxon test.
fig, ax = plt.subplots()
print(d.log2_ratio.mean())
print(sp.stats.wilcoxon(d["PostEEN"], d["EEN"]))
ax.hist(d.log2_ratio, bins=20)

# Paired EEN -> PostEEN lines, coloured by the log2 ratio rescaled into [0, 1]
# (0.5 corresponds to no change).
fig, ax = plt.subplots(figsize=(3, 8))
max_c_value = d.log2_ratio.abs().max()
_line_table = d.assign(c=(d.log2_ratio / max_c_value + 1) / 2)
for subject_id, (een_rabund, post_rabund, log2_ratio, c) in _line_table.iterrows():
    ax.plot([0, 1], [een_rabund, post_rabund], c=mpl.cm.coolwarm(c), lw=4)
ax.set_yscale("log")
ax.set_xticks([0, 1])
ax.set_xticklabels(["EEN", "PostEEN"])
ax.set_xlim(-0.1, 1.1)

In [None]:
# One figure per (subject, sample type) combination for the non-human samples
# of subjects A, B and H.
for subject_id, sample_type in product(["A", "B", "H"], ["Fermenter", "mouse"]):
    # All samples of this subject and sample type.
    all_sample_list = (
        sample[lambda x: (x.sample_type == sample_type) & (x.subject_id == subject_id)]
        .index.to_series()
        .pipe(list)
    )
    # The subset that also appears in the strain fit, sorted by the
    # experimental-design columns so related samples sit next to each other.
    subject_sample_order = (
        sample[
            lambda x: (x.sample_type == sample_type)
            & (x.subject_id == subject_id)
            & (x.index.isin(sf_fit.sample.values))
        ]
        .sort_values(
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ]
        )
        .index.to_series()
        .pipe(list)
    )
    # Cull low abundance strains for each subject

    strain_frac_sample_list = list(set(all_sample_list) & set(sf_fit.sample.values))
    if len(strain_frac_sample_list) == 0:
        print(f"No strain analysis for {subject_id}.")
        # Empty placeholders so the join and the bar loop below have nothing to add.
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Keep the genotype-linkage order for the retained strains; the "other"
        # strain ("-1") always goes last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()

    # Per-sample table: design columns, rOTU relative abundance, strain
    # fractions, and an integer x position following the sorted order.
    d0 = (
        sample.loc[
            subject_sample_order,
            [
                "sample_type",
                "diet_or_media",
                "mouse_genotype",
                "source_samples",
                "status_mouse_inflamed",
            ],
        ]
        .assign(
            rotu_rabund=rotu_rabund[rotu_id],
        )
        .join(comm)
        .assign(xpos=lambda x: np.arange(len(x.index)))
    )

    # Figure width scales with the number of samples that have strain data.
    fig, ax = plt.subplots(
        figsize=(0.4 * len(strain_frac_sample_list), 4),
    )

    # Only a y column is named, so x is implicit (presumably the sample-ID index
    # of d0, which would line up with `xpos` on the twin axis) — TODO confirm.
    ax.plot(
        "rotu_rabund",
        data=d0,
        marker="o",
        linestyle="",
        label="__nolegend__",
        color="k",
        markersize=5,
    )

    # Plot stacked barplot
    ax1 = ax.twinx()  # Strain fractions live on a secondary y-axis.
    top_last = 0  # Running top of the stack; becomes a per-sample Series after the first add.
    for strain, d1 in d0[_strain_order].T.iterrows():
        ax1.bar(
            x=d0.loc[d1.index].xpos,
            height=d1,
            width=1,
            bottom=top_last,
            alpha=1.0,
            color=strain_palette[strain],
            edgecolor="k",
            lw=1,
            label=strain,
        )
        top_last += d1
    ax1.set_yticks([])

    # Put strains behind points:
    ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
    ax.patch.set_visible(False)  # hide the 'canvas'
    ax1.patch.set_visible(True)  # show the 'canvas'

    # # Start and end of EEN
    # ax.axvline(0, lw=1, linestyle="--", color="k")
    # ax.axvline(
    #     trnsfm_x(subject.een_start_date_relative_een_end.loc[subject_id]),
    #     lw=1,
    #     linestyle="--",
    #     color="k",
    # )
    ax1.legend(bbox_to_anchor=(1, 1), ncols=2)

    # NOTE(review): the title shows only the subject, although a separate figure
    # is drawn for each sample type.
    ax.set_title(subject_id)

    # xtick_pos = np.array([-100, -50, -10, 0, 10, 50, 100, 300])
    # ax.set_xticks(trnsfm_x(xtick_pos))
    # ax.set_xticklabels(xtick_pos)
    ax.set_yscale("symlog", linthresh=ylinthresh)
    ax.set_ylim(0, 1.0)
    ax.set_xlim(-0.5, len(strain_frac_sample_list) - 0.5)
    lib.plot.rotate_xticklabels(ax=ax)

## Plot complex experimental design

In [None]:
def _cf(*numbers):
    """Expand sample numbers into ``CF_<number>`` sample IDs, preserving order."""
    return [f"CF_{n}" for n in numbers]


# Panels of the transfer-experiment figures, keyed by subject and then by panel
# title (insertion order is the plotting order). Each value is a pair
# ``(num_offset_samples, sample_ids)``; the plotting cells may shift the first
# `num_offset_samples` bars slightly to the left of the remaining ones.
subject_transfer_sample_lists = {
    # B – PreEEN (active disease)
    # CF-7 with direct transfer CF 379,380,381,384,385,386
    # CF7 with ex vivo samples CF 97,98,99,100 (first FR, then EEN)
    # Ex vivo with post ex vivo transfer
    # CF 97,98 with 397,406,408,409,395,402
    # CF 99,100 with 430,426,427,428,429,431
    "B": {
        "Subject B:\nTime Series": (0, _cf(7, 8, 9, 10, 11, 12, 13, 14)),
        "Direct Transfer": (1, _cf(7, 379, 380, 381, 384, 385, 386)),
        "Ex Vivo": (1, _cf(7, 97, 98, 99, 100)),
        "Post Ex Vivo Transfer #1": (2, _cf(97, 98, 397, 406, 408, 409, 395, 402)),
        "Post Ex Vivo Transfer #2": (2, _cf(99, 100, 430, 426, 427, 428, 429, 431)),
    },
    # A – EEN (remission)
    # CF-3 with direct transfer CF 140,141,142,149,150,151
    # CF-3 with ex vivo samples CF 103,104,101,102 (started with EEN, then FR)
    # Ex vivo with post ex vivo transfer
    # CF 103,104 with 170-175
    # CF 101,102 with 152-157
    "A": {
        "Subject A:\nTime Series": (0, _cf(1, 2, 3, 6)),
        "Direct Transfer": (1, _cf(3, 140, 141, 142, 149, 150, 151)),
        "Ex Vivo": (1, _cf(3, 103, 104, 101, 102)),
        "Post Ex Vivo Transfer #1": (2, _cf(103, 104, 170, 171, 172, 173, 174, 175)),
        "Post Ex Vivo Transfer #2": (2, _cf(101, 102, 152, 153, 154, 155, 156, 157)),
    },
    # H – EEN (persistent inflammation)
    # CF48 with direct transfer CF 115-120
    # CF48 with ex vivo samples 107,108 (109,110): first EEN, then FR (no FR transfer in mice)
    # Ex vivo with post ex vivo transfer
    # 107,108 with 127-133 (chow) and 667-672 (EEN-like mouse diet; PD)
    "H": {
        "Subject H:\nTime Series": (0, _cf(46, 47, 48, 49, 50, 51)),
        "Direct Transfer": (1, _cf(48, 115, 116, 117, 118, 119, 120)),
        "Ex Vivo": (1, _cf(48, 107, 108, 109, 110)),
        # CF_129 is deliberately left out: it is in Debbie's email, but seems to be wrong.
        "Post Ex Vivo - Chow Diet": (2, _cf(107, 108, 127, 128, 130, 131, 132, 133)),
        "Post Ex Vivo - EEN-Like Diet": (2, _cf(107, 108, 667, 668, 669, 670, 671, 672)),
    },
}

### Plots

In [None]:
# Species ID used to locate the strain fit and to name the saved figures.
species = "101493"
# Order in which subjects are laid out in the per-subject panels.
subject_order = "A B H C D E F G K L N O Q R S T U M P".split()

print(motu_taxonomy.loc[species])


In [None]:
def _standardize_sample_name(s):
    """Rewrite a raw sample label as ``CF_<int>`` from its second "_"-separated field."""
    return "CF_{}".format(int(s.split("_")[1]))


# Load the sfacts fit for `species`, then normalise its coordinates step by step.
strain_fit = sf.data.World.load(
    f"data/group/een/species/sp-{species}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
)
strain_fit = strain_fit.rename_coords(sample=_standardize_sample_name)
# Exchange the labels of samples CF_11 and CF_15.
strain_fit = strain_fit.rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
strain_fit = strain_fit.drop_low_abundance_strains(0.2)
strain_fit = strain_fit.rename_coords(strain=str)
print(strain_fit.sizes)

# Order strains by genotype similarity; the "other" strain ("-1") is moved to
# the end of that order and drawn in light grey.
strain_linkage = strain_fit.genotype.linkage(optimal_ordering=True)
strain_order = list(linkage_order(strain_linkage, strain_fit.strain.values))
strain_order.remove("-1")
strain_order.append("-1")
strain_palette = lib.plot.construct_ordered_palette(
    strain_order, cm="rainbow", extend={"-1": "lightgrey"}
)

# Displayed as the cell output: first element of the metagenotype error result.
sf.evaluation.metagenotype_error2(strain_fit, discretized=True)[0]

In [None]:
# Fix the global NumPy seed, presumably so the random position subsample below
# is reproducible.
np.random.seed(0)

sample_linkage = strain_fit.unifrac_linkage(optimal_ordering=True)
# Subsample at most 1000 positions (fewer if the fit has fewer).
_n_positions = min(strain_fit.sizes["position"], 1000)
position_ss = strain_fit.random_sample(position=_n_positions).position


def _use_sample_linkage(w):
    """Ignore the plotted world and reuse the precomputed sample linkage."""
    return sample_linkage


# Both heatmaps order their columns by the same sample linkage.
sf.plot.plot_metagenotype(
    strain_fit.sel(position=position_ss),
    col_linkage_func=_use_sample_linkage,
)
sf.plot.plot_community(strain_fit, col_linkage_func=_use_sample_linkage)

In [None]:
ncols = 4

# One panel per subject on a grid, sharing the y-axis.
fig, axs = lib.plot.subplots_grid(
    ncols=ncols,
    naxes=len(subject_order),
    ax_width=4,
    gridspec_kw=dict(hspace=1.2, wspace=0.0),
    sharey=True,
)

# NOTE: the loop variable is `subject_id`, not `subject`, so that it does not
# overwrite the `subject` metadata table that other cells index into.
for subject_id, ax in zip(subject_order, axs.flatten()):
    # Human samples of this subject that have mOTU abundances, in time order.
    subject_sample_order = sample.sort_values("collection_date_relative_een_end")[
        lambda x: (x.subject_id == subject_id)
        & (x.sample_type == "human")
        & (x.index.isin(motu_rabund.index))
    ].index
    try:
        subject_comm_sample_list = list(
            set(subject_sample_order) & set(strain_fit.sample.values)
        )
        # Strain fractions for this subject; strains below the threshold are
        # aggregated into the "-1" ("other") strain.
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # Fallback (presumably for subjects without any fitted samples): an
        # empty table that only has the "other" column.
        subject_comm = pd.DataFrame([], columns=["-1"])

    ax.set_title(subject_id)
    # Sample metadata in time order, an integer x position `t`, and the strain
    # fractions joined on sample ID.
    d = (
        sample.reindex(subject_sample_order)
        # .dropna(subset=["collection_date_relative_een_end"])
        # .sort_values("collection_date_relative_een_end")
        .assign(
            t=lambda x: range(len(x)),
        )
    ).join(subject_comm)

    plot_stacked_barplot(
        data=d,
        x_var="t",
        order=[s for s in strain_order if s in subject_comm.columns],
        palette=strain_palette,
        ax=ax,
    )
    # ax.legend(bbox_to_anchor=(1, 1))
    ax.set_xticklabels(d.timepoint)
    ax.set_aspect(3, anchor="NW")
    lib.plot.rotate_xticklabels(rotation=45, ax=ax)
    ax.set_yticks([0, 0.5, 1.0])
    ax.set_xlim(-0.5, d.shape[0] - 0.5)

fig.savefig(f"fig/{species}.een_strain_time_series_plot.pdf")

In [None]:
# Grid of 3 subjects (rows) x 5 sample lists (columns), as laid out in
# `subject_transfer_sample_lists`.
fig, axs = plt.subplots(
    3,
    5,
    figsize=(5 * 5, 3 * 3.5),
    squeeze=False,
    sharey=True,
    gridspec_kw=dict(hspace=1.5, wspace=0),
)

# NOTE: the loop variable is `subject_id`, not `subject`, so that it does not
# overwrite the `subject` metadata table that other cells index into.
for subject_id, axs_row in zip(subject_transfer_sample_lists, axs):
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject_id)) & set(strain_fit.sample.values)
    )

    try:
        # Strain fractions for this subject; strains below the threshold are
        # aggregated into the "-1" ("other") strain.
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # Fallback (presumably for subjects without any fitted samples): an
        # empty table that only has the "other" column.
        subject_comm = pd.DataFrame([], columns=["-1"])

    for (sample_list_label, (num_offset_samples, sample_list)), ax in zip(
        subject_transfer_sample_lists[subject_id].items(), axs_row
    ):
        # `t` is created as float so that the fractional offset applied below
        # does not depend on pandas silently upcasting an integer column.
        # `simple_label`: timepoint for human samples, diet/media for fermenter
        # samples, and diet/media plus inflammation status for everything else.
        d = (
            sample.reindex(sample_list)
            # .dropna(subset=["collection_date_relative_een_end"])
            # .sort_values("collection_date_relative_een_end")
            .assign(
                t=lambda x: np.arange(len(x), dtype=float),
                simple_label=lambda x: np.where(
                    x.sample_type == "human",
                    x.timepoint,
                    np.where(
                        x.sample_type == "Fermenter",
                        x.diet_or_media,
                        np.where(
                            x.status_mouse_inflamed == "Inflamed",
                            x.diet_or_media + " / inflam",
                            x.diet_or_media + " / not",
                        ),
                    ),
                ),
            )
        ).join(subject_comm)
        # Nudge the leading (source) samples left to set them apart visually.
        d.loc[d.index[:num_offset_samples], "t"] -= 0.2  # Offset width

        plot_stacked_barplot(
            data=d,
            x_var="t",
            order=[s for s in strain_order if s in subject_comm.columns],
            palette=strain_palette,
            ax=ax,
            width=0.8,
            lw=0.5,
        )

        ax.set_title(sample_list_label)
        ax.set_xticklabels(d.simple_label)
        ax.set_aspect(4, anchor="NW")
        lib.plot.rotate_xticklabels(rotation=45, ax=ax)
        ax.set_yticks([0, 0.5, 1.0])
        ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
    # One legend per row, attached to the last panel drawn for the subject.
    ax.legend(bbox_to_anchor=(1, 1), ncols=2)

fig.savefig(f"fig/{species}.een_strain_transfer_expt_plot.pdf")

In [None]:
# Same layout as the transfer-experiment figure (3 subjects x 5 sample lists),
# but with the x ticks labelled by sample ID.
fig, axs = plt.subplots(
    3,
    5,
    figsize=(5 * 5, 3 * 3.5),
    squeeze=False,
    sharey=True,
    gridspec_kw=dict(hspace=1.5, wspace=0),
)

# NOTE: the loop variable is `subject_id`, not `subject`, so that it does not
# overwrite the `subject` metadata table that other cells index into.
for subject_id, axs_row in zip(subject_transfer_sample_lists, axs):
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject_id)) & set(strain_fit.sample.values)
    )

    try:
        # Strain fractions for this subject; strains below the threshold are
        # aggregated into the "-1" ("other") strain.
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # Fallback (presumably for subjects without any fitted samples): an
        # empty table that only has the "other" column.
        subject_comm = pd.DataFrame([], columns=["-1"])

    for (sample_list_label, (num_offset_samples, sample_list)), ax in zip(
        subject_transfer_sample_lists[subject_id].items(), axs_row
    ):
        # `t` is created as float so that the fractional offset applied below
        # does not depend on pandas silently upcasting an integer column.
        d = (
            sample.reindex(sample_list)
            # .dropna(subset=["collection_date_relative_een_end"])
            # .sort_values("collection_date_relative_een_end")
            .assign(
                t=lambda x: np.arange(len(x), dtype=float),
                simple_label=lambda x: np.where(
                    x.sample_type == "human",
                    x.timepoint,
                    np.where(
                        x.sample_type == "Fermenter",
                        x.diet_or_media,
                        np.where(
                            x.status_mouse_inflamed == "Inflamed",
                            x.diet_or_media + " / inflam",
                            x.diet_or_media + " / not",
                        ),
                    ),
                ),
            )
        ).join(subject_comm)
        # Nudge the leading (source) samples left to set them apart visually.
        d.loc[d.index[:num_offset_samples], "t"] -= 0.2  # Offset width

        plot_stacked_barplot(
            data=d,
            x_var="t",
            order=[s for s in strain_order if s in subject_comm.columns],
            palette=strain_palette,
            ax=ax,
            width=0.8,
            lw=0.5,
        )

        ax.set_title(sample_list_label)
        ax.set_xticklabels(d.index)  # Sample IDs instead of the simplified labels.
        ax.set_aspect(4, anchor="NW")
        lib.plot.rotate_xticklabels(rotation=45, ax=ax)
        ax.set_yticks([0, 0.5, 1.0])
        ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
    # One legend per row, attached to the last panel drawn for the subject.
    ax.legend(bbox_to_anchor=(1, 1), ncols=2)

fig.savefig(f"fig/{species}.een_strain_transfer_expt_plot.sample_id.pdf")

In [None]:
# Compare the metagenotype frequency spectra of two samples from the fit.
sf.plot.plot_metagenotype_frequency_spectrum_compare_samples(
    strain_fit,
    ["CF_6", "CF_3"],
)

In [None]:
# Paper-style version of the transfer-experiment figure: 3 subjects (rows) x
# 5 sample lists (columns), independent y-axes, percentage tick labels.
fig, axs = plt.subplots(
    3,
    5,
    figsize=(5 * 5, 3 * 3.5),
    squeeze=False,
    sharey=False,
    gridspec_kw=dict(hspace=1.5, wspace=0),
)

# NOTE: the loop variable is `subject_id`, not `subject`, so that it does not
# overwrite the `subject` metadata table that other cells index into.
for subject_id, axs_row in zip(subject_transfer_sample_lists, axs):
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject_id)) & set(strain_fit.sample.values)
    )

    try:
        # Strain fractions for this subject; strains below the threshold are
        # aggregated into the "-1" ("other") strain.
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # Fallback (presumably for subjects without any fitted samples): an
        # empty table that only has the "other" column.
        subject_comm = pd.DataFrame([], columns=["-1"])

    for (sample_list_label, (num_offset_samples, sample_list)), ax in zip(
        subject_transfer_sample_lists[subject_id].items(), axs_row
    ):
        # `simple_label`: timepoint for human samples, diet/media for fermenter
        # samples, and diet/media plus inflammation status for everything else.
        d = (
            sample.reindex(sample_list)
            # .dropna(subset=["collection_date_relative_een_end"])
            # .sort_values("collection_date_relative_een_end")
            .assign(
                t=lambda x: range(len(x)),
                simple_label=lambda x: np.where(
                    x.sample_type == "human",
                    x.timepoint,
                    np.where(
                        x.sample_type == "Fermenter",
                        x.diet_or_media,
                        np.where(
                            x.status_mouse_inflamed == "Inflamed",
                            x.diet_or_media + " / inflam",
                            x.diet_or_media + " / not",
                        ),
                    ),
                ),
            )
        ).join(subject_comm)
        # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

        plot_stacked_barplot(
            data=d,
            x_var="t",
            order=[s for s in strain_order if s in subject_comm.columns],
            palette=strain_palette,
            ax=ax,
            width=0.8,
            lw=0,
        )

        # Mark inflammation: an asterisk at the top of each inflamed sample's bar.
        for _, x in d.iterrows():
            if x.status_mouse_inflamed == "Inflamed":
                ax.annotate("*", xy=(x.t, 1.0))

        ax.set_title(sample_list_label, pad=25)
        ax.set_xticklabels(d.simple_label)
        ax.set_aspect(9, anchor="NW")
        ax.set_ylim(0, 1.0)
        lib.plot.rotate_xticklabels(rotation=45, ax=ax)
        ax.set_yticks(np.linspace(0, 1.0, num=6))
        # Show fractions as percentages without the "%" sign.
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol=""))
        ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
        ax.spines[["right", "top"]].set_visible(False)
    # One legend per row, attached to the last panel drawn for the subject.
    ax.legend(bbox_to_anchor=(1, 1), ncols=2)

# NOTE(review): other cells in this notebook save to this exact path as well;
# whichever cell runs last determines the file's contents — confirm intended.
fig.savefig(f"fig/{species}.een_strain_transfer_expt_plot.paper_style.pdf")

In [None]:
# Plot strain genotypes at a random subset of at most 1000 positions, with rows
# ordered by the precomputed genotype linkage. The request is capped at the
# number of positions in the fit, matching how the metagenotype plot earlier in
# this notebook draws its subsample, so small fits do not over-request.
sf.plot.plot_genotype(
    strain_fit.random_sample(position=min(strain_fit.sizes["position"], 1000)),
    row_linkage_func=lambda w: strain_linkage,
    scaley=0.4,
)

In [None]:
# Cluster strains by their mean abundance profile across subjects (cosine
# distance), then shuffle that order so that strains adjacent in the clustering
# end up far apart in the palette.
strain_subject_coabundance_linkage = sp.cluster.hierarchy.linkage(
    strain_fit.community.to_pandas().groupby(sample.subject_id).mean().T,
    metric="cosine",
    optimal_ordering=True,
)
shuffled_strain_order = maximally_shuffled_order(
    linkage_order(
        strain_subject_coabundance_linkage,
        labels=strain_fit.strain.values,
    )
)
shuffled_strain_order.remove("-1")  # Drop "other" strain.
shuffled_strain_order.append("-1")  # Add to end of list
shuffled_strain_palette = lib.plot.construct_ordered_palette(
    shuffled_strain_order,
    cm="rainbow",
    extend={"-1": "lightgrey"},
)

# Paper-style transfer-experiment figure drawn with the shuffled palette.
fig, axs = plt.subplots(
    3,
    5,
    figsize=(5 * 5, 3 * 3.5),
    squeeze=False,
    sharey=False,
    gridspec_kw=dict(hspace=1.5, wspace=0),
)

# NOTE: the loop variable is `subject_id`, not `subject`, so that it does not
# overwrite the `subject` metadata table that other cells index into.
for subject_id, axs_row in zip(subject_transfer_sample_lists, axs):
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject_id)) & set(strain_fit.sample.values)
    )

    try:
        # Strain fractions for this subject; strains below the threshold are
        # aggregated into the "-1" ("other") strain.
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # Fallback (presumably for subjects without any fitted samples): an
        # empty table that only has the "other" column.
        subject_comm = pd.DataFrame([], columns=["-1"])

    for (sample_list_label, (num_offset_samples, sample_list)), ax in zip(
        subject_transfer_sample_lists[subject_id].items(), axs_row
    ):
        # `simple_label`: timepoint for human samples, diet/media for fermenter
        # samples, and diet/media plus inflammation status for everything else.
        d = (
            sample.reindex(sample_list)
            # .dropna(subset=["collection_date_relative_een_end"])
            # .sort_values("collection_date_relative_een_end")
            .assign(
                t=lambda x: range(len(x)),
                simple_label=lambda x: np.where(
                    x.sample_type == "human",
                    x.timepoint,
                    np.where(
                        x.sample_type == "Fermenter",
                        x.diet_or_media,
                        np.where(
                            x.status_mouse_inflamed == "Inflamed",
                            x.diet_or_media + " / inflam",
                            x.diet_or_media + " / not",
                        ),
                    ),
                ),
            )
        ).join(subject_comm)
        # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

        plot_stacked_barplot(
            data=d,
            x_var="t",
            order=[s for s in shuffled_strain_order if s in subject_comm.columns],
            palette=shuffled_strain_palette,
            ax=ax,
            width=0.8,
            lw=0,
        )

        # Mark inflammation: an asterisk at the top of each inflamed sample's bar.
        for _, x in d.iterrows():
            if x.status_mouse_inflamed == "Inflamed":
                ax.annotate("*", xy=(x.t, 1.0))

        ax.set_title(sample_list_label, pad=25)
        ax.set_xticklabels(d.simple_label)
        ax.set_aspect(9, anchor="NW")
        ax.set_ylim(0, 1.0)
        lib.plot.rotate_xticklabels(rotation=45, ax=ax)
        ax.set_yticks(np.linspace(0, 1.0, num=6))
        # Show fractions as percentages without the "%" sign.
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol=""))
        ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
        ax.spines[["right", "top"]].set_visible(False)
    # One legend per row, attached to the last panel drawn for the subject.
    ax.legend(bbox_to_anchor=(1, 1), ncols=2)

# NOTE(review): other cells in this notebook save to this exact path as well;
# whichever cell runs last determines the file's contents — confirm intended.
fig.savefig(f"fig/{species}.een_strain_transfer_expt_plot.paper_style.pdf")

In [None]:
# Trigger a reload of imported modules via the autoreload extension loaded in
# the preamble (presumably to pick up edits to `lib.plot` before the next cell).
%autoreload

In [None]:
# Palette in the genotype-linkage strain order, additionally passing three
# desaturation levels to the palette constructor.
more_colors_strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
    extend={"-1": "lightgrey"},
    desaturate_levels=[1.0, 0.7, 0.4],
)

# Paper-style transfer-experiment figure drawn with the extended palette.
fig, axs = plt.subplots(
    3,
    5,
    figsize=(5 * 5, 3 * 3.5),
    squeeze=False,
    sharey=False,
    gridspec_kw=dict(hspace=1.5, wspace=0),
)

# NOTE: the loop variable is `subject_id`, not `subject`, so that it does not
# overwrite the `subject` metadata table that other cells index into.
for subject_id, axs_row in zip(subject_transfer_sample_lists, axs):
    subject_comm_sample_list = list(
        set(idxwhere(sample.subject_id == subject_id)) & set(strain_fit.sample.values)
    )

    try:
        # Strain fractions for this subject; strains below the threshold are
        # aggregated into the "-1" ("other") strain.
        subject_comm = (
            strain_fit.sel(sample=subject_comm_sample_list)
            .drop_low_abundance_strains(0.2, agg_strain_coord="-1")
            .community.to_pandas()
        )
    except ValueError:
        # Fallback (presumably for subjects without any fitted samples): an
        # empty table that only has the "other" column.
        subject_comm = pd.DataFrame([], columns=["-1"])

    for (sample_list_label, (num_offset_samples, sample_list)), ax in zip(
        subject_transfer_sample_lists[subject_id].items(), axs_row
    ):
        # `simple_label`: timepoint for human samples, diet/media for fermenter
        # samples, and diet/media plus inflammation status for everything else.
        d = (
            sample.reindex(sample_list)
            # .dropna(subset=["collection_date_relative_een_end"])
            # .sort_values("collection_date_relative_een_end")
            .assign(
                t=lambda x: range(len(x)),
                simple_label=lambda x: np.where(
                    x.sample_type == "human",
                    x.timepoint,
                    np.where(
                        x.sample_type == "Fermenter",
                        x.diet_or_media,
                        np.where(
                            x.status_mouse_inflamed == "Inflamed",
                            x.diet_or_media + " / inflam",
                            x.diet_or_media + " / not",
                        ),
                    ),
                ),
            )
        ).join(subject_comm)
        # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

        plot_stacked_barplot(
            data=d,
            x_var="t",
            order=[s for s in strain_order if s in subject_comm.columns],
            palette=more_colors_strain_palette,
            ax=ax,
            width=0.8,
            lw=0,
        )

        # Mark inflammation: an asterisk at the top of each inflamed sample's bar.
        for _, x in d.iterrows():
            if x.status_mouse_inflamed == "Inflamed":
                ax.annotate("*", xy=(x.t, 1.0))

        ax.set_title(sample_list_label, pad=25)
        ax.set_xticklabels(d.simple_label)
        ax.set_aspect(9, anchor="NW")
        ax.set_ylim(0, 1.0)
        lib.plot.rotate_xticklabels(rotation=45, ax=ax)
        ax.set_yticks(np.linspace(0, 1.0, num=6))
        # Show fractions as percentages without the "%" sign.
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol=""))
        ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
        ax.spines[["right", "top"]].set_visible(False)
    # One legend per row, attached to the last panel drawn for the subject.
    ax.legend(bbox_to_anchor=(1, 1), ncols=2)

# NOTE(review): other cells in this notebook save to this exact path as well;
# whichever cell runs last determines the file's contents — confirm intended.
fig.savefig(f"fig/{species}.een_strain_transfer_expt_plot.paper_style.pdf")