## Preamble

In [None]:
# Load IPython's autoreload extension (for picking up edits to project modules such
# as lib.*).
# NOTE(review): loading the extension alone may not switch reloading on; the usual
# idiom follows it with `%autoreload 2` — confirm intent.
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so the relative data/, meta/, raw/ and ref/ paths used
# below resolve. Raises KeyError if PROJECT_ROOT is not set in the environment.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Display the resolved working directory as a sanity check.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
# MetaPhlAn species table for the EEN metagenomes. After set_index + transpose, rows
# are the file's non-index columns (presumably samples) and columns are clade names.
# Dividing by 100 presumably converts percentages to fractions — TODO confirm units.
een_metaphlan_rabund = (
    pd.read_table("raw/een-mgen/2023-06-13_aritra.mahapatra@tum.de/6_species.tab")
    .set_index("clade_name")
    .T
    / 100
)
# E. coli column, selected by its full MetaPhlAn clade string.
een_metaphlan_rabund_ecoli = een_metaphlan_rabund[
    "k__Bacteria|p__Proteobacteria|c__Gammaproteobacteria|o__Enterobacterales|f__Enterobacteriaceae|g__Escherichia|s__Escherichia_coli"
]
een_metaphlan_rabund_ecoli

In [None]:
# Recreation of one of Aritra's plots. Looks identical.
# Per-sample count of MetaPhlAn columns with values above 0.001.
plt.hist((een_metaphlan_rabund > 0.001).sum(1), bins=20)

In [None]:
# Mean over EEN samples of the total GT-Pro depth (depth summed across species).
pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv", index_col=["sample", "species_id"]
).depth.unstack(fill_value=0).sum(1).mean()

In [None]:
# Overlaid histograms of GT-Pro depth of species 102506 (E. coli) in the HMP2
# cohort and in the EEN samples.
hmp2_ecoli_depth = pd.read_table(
    "data/group/xjin_hmp2/r.proc.gtpro.species_depth.tsv",
    index_col=["sample", "species_id"],
).depth.unstack(fill_value=0)[102506]
een_ecoli_depth = pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv", index_col=["sample", "species_id"]
).depth.unstack(fill_value=0)[102506]

fig, ax = plt.subplots()
# A leading 0 edge gives zero-depth samples their own bin below the log-spaced ones.
bins = [0] + list(np.logspace(-3, 3))

for label, (x, color) in dict(
    hmp2=(hmp2_ecoli_depth, "tab:blue"), een=(een_ecoli_depth, "tab:orange")
).items():
    ax.hist(x, bins=bins, label=label, alpha=0.6, color=color)

ax.legend()
ax.set_xscale("symlog", linthresh=1e-3, linscale=0.1)
ax.set_yscale("log")
# Both cohorts share this axis, so the label must not use the loop variable `label`
# (after the loop it always holds the last key, "een").
ax.set_ylabel("samples")
ax.set_xlabel("depth")
None

In [None]:
# Per-sample relative abundance of each GT-Pro species in the EEN samples: every
# species depth divided by the sample's total depth. Column labels are the integer
# species ids as read from the file.
_een_species_depth = pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv",
    index_col=["sample", "species_id"],
).depth.unstack(fill_value=0)
species_rabund = _een_species_depth.div(_een_species_depth.sum(axis=1), axis=0)

In [None]:
# Overlaid histograms of E. coli (GT-Pro species 102506) relative abundance in HMP2
# (GT-Pro), EEN (GT-Pro) and EEN (MetaPhlAn).
hmp2_ecoli_rabund = (
    pd.read_table(
        "data/group/xjin_hmp2/r.proc.gtpro.species_depth.tsv",
        index_col=["sample", "species_id"],
    )
    .depth.unstack(fill_value=0)
    # Row-normalize depths to relative abundance (vectorized division by row sums).
    .pipe(lambda depth: depth.div(depth.sum(axis=1), axis=0))[102506]
)
# Uses the int-keyed `species_rabund` built from the EEN depth table above.
een_ecoli_rabund = species_rabund[102506]

fig, ax = plt.subplots()
# A leading 0 edge gives zero-abundance samples their own bin.
bins = [0] + list(np.logspace(-7, 1))

for label, (x, color) in dict(
    hmp2=(hmp2_ecoli_rabund, "grey"),
    een=(een_ecoli_rabund, "tab:blue"),
    metaphlan=(een_metaphlan_rabund_ecoli, "tab:orange"),
).items():
    ax.hist(x, bins=bins, label=label, alpha=0.6, color=color)

ax.legend()
ax.set_xscale("symlog", linthresh=1e-7, linscale=0.1)
ax.set_yscale("log")
# All three sources share this axis; the loop variable `label` would always read
# "metaphlan" here, so use a neutral label.
ax.set_ylabel("samples")
ax.set_xlabel("relative abundance")
None

In [None]:
# 2D histogram of per-sample E. coli relative abundance, GT-Pro (x) vs MetaPhlAn (y).
# align_indexes presumably restricts both series to their shared samples — confirm.
x, y = align_indexes(een_ecoli_rabund, een_metaphlan_rabund_ecoli)

# A leading 0 edge collects exact zeros; the remaining edges are log-spaced up to 1.
bins = [0] + list(np.logspace(-7, 0, num=20))

fig, ax = plt.subplots()
# hist2d returns (counts, xedges, yedges, image); the image drives the colorbar.
*_, cbar_artist = ax.hist2d(
    x, y, bins=bins, norm=mpl.colors.PowerNorm(1 / 2), cmap="magma_r"
)
ax.set_aspect(1)

ax.set_xscale("symlog", linthresh=1e-7, linscale=1)
ax.set_yscale("symlog", linthresh=1e-7, linscale=1)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel("relative abundance (GT-Pro)")
ax.set_ylabel("relative abundance (MetaPhlAn)")

# Reserve the right margin for a manually placed colorbar axis.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(cbar_artist, cax=cbar_ax, label="count samples")

In [None]:
# Per-sample scatter of E. coli relative abundance, GT-Pro vs MetaPhlAn, on symlog axes.
x, y = align_indexes(een_ecoli_rabund, een_metaphlan_rabund_ecoli)


fig, ax = plt.subplots()

ax.scatter(x, y, alpha=0.5)
ax.set_xlabel("relative abundance (GT-Pro)")
ax.set_ylabel("relative abundance (MetaPhlAn)")
ax.set_xscale("symlog", linthresh=1e-7, linscale=0.1)
ax.set_yscale("symlog", linthresh=1e-7, linscale=0.1)
# Dotted guides at 1e-3 (0.1%) on both axes.
ax.axvline(1e-3, linestyle=":", color="k", lw=1)
ax.axhline(1e-3, linestyle=":", color="k", lw=1)

ax.set_aspect(1)
# Dashed 1:1 line: points on it agree between the two profilers.
ax.plot([0, 1], [0, 1], linestyle="--", color="k")

In [None]:
# EEN GT-Pro species depth table (samples x species), with the species-id column
# labels converted to str.
species_depth = (
    pd.read_table(
        "data/group/een/r.proc.gtpro.species_depth.tsv",
        index_col=["sample", "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str)
)

fig, axs = plt.subplots(2, sharex=True)

# Top panel: total depth per sample; bottom panel: total depth per species.
# Both use the same log-spaced bins (0.1x to 100,000x).
for (title, x), ax in zip(
    dict(
        total_depth_by_sample=species_depth.sum(1),
        total_depth_by_species=species_depth.sum(0),
    ).items(),
    axs.flatten(),
):
    ax.hist(x, bins=np.logspace(-1, 5, num=100))
    ax.set_title(title)
    ax.set_xscale("log")
fig.tight_layout()

In [None]:
# Total species depth of a few selected samples.
species_depth.loc[["CF_1", "CF_11", "CF_15", "CF_89"]].sum(1)

In [None]:
# Relative abundance: each sample's species depths divided by its total depth.
# NOTE(review): this rebinds species_rabund with str column ids (from species_depth);
# earlier cells that index it with the int 102506 fail if re-run after this cell.
species_rabund = species_depth.divide(species_depth.sum(1), axis=0)
# Top 20 species by mean relative abundance across samples.
species_rabund.mean().sort_values(ascending=False).head(20)

In [None]:
# Distribution of all (sample, species) depths; the leading 0 edge catches zeros.
bins = [0] + list(np.logspace(-6, 4, num=100))
plt.hist(species_depth.values.flatten(), bins=bins)
plt.xscale("symlog", linthresh=1e-6, linscale=0.1)
plt.yscale("log")

In [None]:
# Per-sample number of species with depth > 0.1x, with summary quantiles.
x = (species_depth > 1e-1).sum(1)
print(x.quantile([0.05, 0.25, 0.5, 0.75, 0.95]))
plt.hist(x, bins=10)
plt.xlabel("Number of species with depth >0.1x")
plt.ylabel("Number of samples")

In [None]:
# Prevalence: fraction of samples in which each species has depth > 0.1x.
species_prevalence = (species_depth > 1e-1).mean()

# Number of species present in more than half of the samples.
print((species_prevalence > 0.5).sum())
# Number of species present (> 0.1x) in at least two samples.
print(((species_depth > 1e-1).sum() >= 2).sum())

plt.hist(species_prevalence, bins=np.linspace(0, 1, num=51))
plt.xlabel("Fraction of samples with depth >0.1x")
plt.ylabel("Number of species")
# A bare `None` as the last expression suppresses display of the last call's result.
None

In [None]:
# Same computation as the relative-abundance cell above, repeated here.
species_rabund = species_depth.divide(species_depth.sum(1), axis=0)

In [None]:
# Per-sample number of species above 0.001 relative abundance.
plt.hist((species_rabund > 0.001).sum(1), bins=20)
plt.xlabel("Number of species with relative abundance >0.1%")
plt.ylabel("Number of samples")

In [None]:
# Median per-sample species count at > 0.001 relative abundance (GT-Pro).
(species_rabund > 0.001).sum(1).median()

In [None]:
# Median per-sample species count at > 0.1x depth (GT-Pro).
(species_depth > 1e-1).sum(1).median()

In [None]:
# Median per-sample count of MetaPhlAn columns above 0.001.
(een_metaphlan_rabund > 0.001).sum(1).median()

In [None]:
# Ridge-style stack of per-species relative-abundance histograms for the 10 species
# most often above 1e-5, each annotated with its prevalence at that threshold.
# NOTE: the loop below rebinds the notebook-global `species_id`.
n_species = 10
top_species = (
    (species_rabund > 1e-5).sum().sort_values(ascending=False).head(n_species).index
)

fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

# Values below 1e-8 (including exact zeros) fall outside these bins and are not drawn.
bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_species, axs):
    # ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(species_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    # Fraction of samples above the same 1e-5 threshold used for ranking.
    prevalence = (species_rabund[species_id] > 1e-5).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    # Hide axes, background and spines so the rows read as one stacked figure.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    # Dotted guide at the 1e-5 threshold.
    ax.axvline(1e-5, lw=1, linestyle=":", color="k")

# After the loop `ax` is the bottom panel: restore its x axis and label it.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# GT-Pro species taxonomy table, loaded through the project helper.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    "ref/gtpro/species_taxonomy_ext.tsv"
)

In [None]:
# Taxonomy string of each id in `top_species` (as last assigned by a ridge-plot cell).
for _species_id in top_species.astype(str):
    print(_species_id, ":", species_taxonomy.taxonomy_string.loc[_species_id])

In [None]:
# Taxonomy rows whose species name ends in "hansenii".
species_taxonomy[lambda x: x.s__.str.endswith('hansenii')]

In [None]:
# For a hand-picked set of species: fraction of samples above 1e-4 and above 1e-3
# relative abundance, then the species name; tab-separated.
for _species_id in ['102544', '102506', '101303', '100150', '102330', '101704']:
    print(_species_id, (species_rabund[_species_id] > 0.0001).mean().round(2), (species_rabund[_species_id] > 0.001).mean().round(2), species_taxonomy.loc[_species_id].s__, sep='\t')

In [None]:
# Fraction of samples where MetaPhlAn E. coli exceeds 1e-3.
(een_metaphlan_rabund_ecoli > 1e-3).mean()

In [None]:
# Ridge-style stack of per-species relative-abundance histograms for the 20 species
# most often above the threshold, each annotated with its prevalence.
# NOTE: the loop below rebinds the notebook-global `species_id`.
n_species = 20
# One threshold drives the ranking, the prevalence annotation and the dotted guide
# line, so the guide marks the cutoff actually used.
_rabund_threshold = 1e-3
top_species = (
    (species_rabund > _rabund_threshold)
    .sum()
    .sort_values(ascending=False)
    .head(n_species)
    .index
)

fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

# Values below 1e-8 (including exact zeros) fall outside these bins and are not drawn.
bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_species, axs):
    ax.hist(species_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    prevalence = (species_rabund[species_id] > _rabund_threshold).mean()
    ax.set_title("")
    # Hide axes, background and spines so the rows read as one stacked figure.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    ax.axvline(_rabund_threshold, lw=1, linestyle=":", color="k")

# After the loop `ax` is the bottom panel: restore its x axis and label it.
ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Strain-level depth table across all species. For each species:
# - read its sfacts strain fractions (samples x strains);
# - keep strains whose fraction summed over samples exceeds 0.05;
# - lump the remainder into an "other" strain with id -1;
# - renormalize each sample and scale by the species' depth in that sample.
# Columns are named "{species_id}_{strain}".
# NOTE: the loop rebinds the notebook-global `species_id`.
strain_depth = []
missing_files = []
for species_id in species_depth.columns:
    path = f'data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv'
    try:
        # Squeeze only the column axis: a bare .squeeze() turns a one-row table into
        # a scalar, which has no .unstack().
        d = (
            pd.read_table(path, index_col=['sample', 'strain'])
            .squeeze(axis="columns")
            .unstack()
        )
    except FileNotFoundError:
        # No strain fit for this species: every sample becomes 100% "other" below.
        missing_files.append(path)
        d = pd.DataFrame([])
    _keep_strains = idxwhere(d.sum() > 0.05)
    # Samples absent from the fit get 0 for every kept strain.
    d = d.reindex(index=species_depth.index, columns=_keep_strains, fill_value=0)
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={'__other': -1})
    # Clip any negative "other" fraction (e.g. from round-off) to 0.
    d[d < 0] = 0
    d = d.divide(d.sum(1), axis=0)
    d = d.multiply(species_depth[species_id], axis=0)
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    strain_depth.append(d)
strain_depth = pd.concat(strain_depth, axis=1)
# Each strain's share of the total strain depth in the sample.
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)
# Species considered vs. species without a strain-fit file.
len(species_depth.columns), len(missing_files)

In [None]:
# Distribution of all strain relative abundances. The 1e-10 offset lands zeros on
# the first bin edge (1e-10) so they are counted rather than dropped by the log bins.
plt.hist(strain_rabund.values.flatten() + 1e-10, bins=np.logspace(-10, 0))
plt.yscale('log')
plt.xscale('log')

In [None]:
# Taxonomy string of each id in `top_species` (as last assigned by a ridge-plot cell).
for _species_id in top_species.astype(str):
    print(_species_id, ":", species_taxonomy.taxonomy_string.loc[_species_id])

In [None]:
# Species examined at the strain level in the cells below.
species_id = "100099"
species_taxonomy.loc[species_id]

In [None]:
# Seed NumPy's global RNG; presumably this makes world.random_sample reproducible
# (assumes sfacts draws from the global NumPy RNG — confirm).
np.random.seed(0)

# Metagenotypes and the fitted sfacts model ("world") for `species_id`.
mgtp_all = sf.data.Metagenotype.load(
    f"data/group/een/species/sp-{species_id}/r.proc.gtpro.mgtp.nc"
)
world = sf.data.World.load(
    f"data/group/een/species/sp-{species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
)
# Random subset of at most 1000 positions, used to keep the heatmaps small.
position_ss = world.random_sample(position=min(1000, world.sizes["position"])).position

# Strain fractions after dropping low-abundance strains (cutoff 0.05); the -1 column
# is the lumped "other" strain and gets no palette entry.
strain_frac = world.drop_low_abundance_strains(0.05).community.to_pandas()
abundant_strain_list = list(strain_frac.columns)
abundant_strain_list.remove(-1)  # Drop "other" strain.
strain_palette = lib.plot.construct_ordered_palette(abundant_strain_list, cm='rainbow')

In [None]:
# sfacts heatmaps on the subsampled positions: metagenotypes (columns ordered by
# metagenotype linkage), community composition (samples by metagenotype linkage,
# strains by genotype linkage) and strain genotypes (rows coloured by entropy).
sf.plot.plot_metagenotype(
    world.sel(position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    dwidth=0.01,
    scalex=0.147,
    col_colors_func=None,
)
sf.plot.plot_community(
    world.sel(position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    row_linkage_func=lambda w: w.genotype.linkage(),
    dwidth=0.01,
    row_colors_func=None,
)
sf.plot.plot_genotype(
    world.sel(position=position_ss),
    row_linkage_func=lambda w: w.genotype.linkage(),
    row_colors_func=lambda w: w.genotype.entropy(),
)

In [None]:
# Stool-sample metadata indexed by metagenome id (mgen_id). The commented-out steps
# are an older cleaning pipeline kept for reference.
meta = (
    pd.read_table("meta/een-mgen/stool.tsv")
    # .rename(columns={'Seq-Name': 'sample', 'CED/Patient-recoded': 'subject_id', 'sampleDate': 'date', 'Diet (=PreEEN, EEN, PostEEN)': 'sample_type'})
    # .assign(
    #     date=lambda x: pd.to_datetime(x.date),
    #     sample_type=lambda x: x.sample_type.fillna('???')
    # )
    .set_index("mgen_id")
    # .sort_values(['subject_id', 'date', 'sample_type'])
)

# # FIXME: Metadata seems to include a swap in the metagenomic data of CF_11 and CF_15.
# meta = meta.rename({'CF_11': 'CF_15', 'CF_15': 'CF_11'})
meta

In [None]:
# Microcosm (in-vitro) sample metadata; the inoculum's subject becomes `subject_id`
# so it lines up with the stool metadata.
micro_meta = pd.read_table("meta/een-mgen/microcosm.tsv").set_index("mgen_id").rename(columns={'inoculum_subject_id': 'subject_id'})
micro_meta

In [None]:
# Every sample in the strain fit is either a stool or a microcosm sample.
assert (world.sample.isin(micro_meta.index) | world.sample.isin(meta.index)).all()

In [None]:
# Every stool sample has species abundance data.
assert meta.index.isin(species_rabund.index).all()

In [None]:
# Sample counts per subject and sample type, subjects with the most EEN samples first.
meta[["subject_id", "sample_type"]].value_counts().unstack(fill_value=0).sort_values(
    "EEN", ascending=False
).head(20)

In [None]:
# Per-subject stacked bars of strain fractions for `species_id`, with the species'
# relative abundance overlaid as a black line on a twin y axis.
# Stool samples (meta) and microcosm samples (micro_meta) are combined. Microcosm
# samples get an infinite collection date, so they sort after that subject's stool
# samples.
_species_rabund = species_rabund[species_id]
_frac = strain_frac
# Strains exceeding 5% in at least one sample (the -1 "other" column is eligible).
_strain_list = idxwhere(_frac.max() > 0.05)
d0 = (
    pd.concat(
        [
            meta.assign(
                label=lambda x: x.assign(idx=x.index)[
                    ["idx", "collection_date_relative_een_end", "sample_type"]
                ].apply(tuple, axis=1)
            ),
            micro_meta.assign(
                collection_date_relative_een_end=np.inf,
                sample_type=lambda x: x.inoculum_mgen_id.map(lambda s: f"invitro[{s}]"),
                label=lambda x: x.assign(idx=x.index)[["idx", "sample_type"]].apply(
                    tuple, axis=1
                ),
            ),
        ]
    )
    .assign(species_rabund=_species_rabund)
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)


def _plot_subject_strain_bars(subjects, ncols):
    """Draw one strain-fraction bar panel per subject in *subjects*.

    Reads ``d0``, ``_strain_list`` and ``strain_palette`` from the notebook namespace
    and returns the ``(fig, axs)`` pair created by ``lib.plot.subplots_grid``.
    """
    fig, axs = lib.plot.subplots_grid(
        ncols=ncols, naxes=len(subjects), ax_width=5, ax_height=4,
    )
    for subject_id, ax in zip(subjects, axs.flatten()):
        ax.set_title(subject_id)
        twin_ax = ax.twinx()
        d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
        d1[_strain_list].plot(
            kind="bar",
            width=0.95,
            stacked=True,
            color=strain_palette,
            ax=ax,
        )
        d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
        ax.legend_.set_visible(False)
        ax.set_ylim(0, 1)
        ax.set_ylabel("strain fraction")
        ax.set_xlabel("")
        twin_ax.set_ylabel("species relative abundance")
        twin_ax.set_ylim(0)
        lib.plot.rotate_xticklabels(ax)
        # Same window of 16 bar positions in every panel (extra bars are cut off).
        ax.set_xlim(-0.5, 15.5)
    fig.tight_layout()
    return fig, axs


subject_list = [
    # "A",
    # "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    # "H",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
]
fig, axs = _plot_subject_strain_bars(subject_list, ncols=4)

# A, B and H are left out of the grid above and drawn in their own figure.
subject_list = [
    "A",
    "B",
    "H",
]
fig, axs = _plot_subject_strain_bars(subject_list, ncols=3)

In [None]:
# For subjects with exactly one sample just before and one just after their last EEN
# timepoint: days since that timepoint and strain Bray-Curtis dissimilarity to it,
# with columns split by is_next_post_een_tp.
# NOTE(review): `all_subject_data` is only defined by a later cell, and the
# `strain_bc_to_last_een` column exists only in that cell's Bray-Curtis version (the
# Jaccard cell rebuilds it with `diss_to_last_een`). Run this after the Bray-Curtis
# trajectory cell.
d = all_subject_data[lambda x: x.is_previous_een_tp | x.is_next_post_een_tp]
_subject_list = (
    d[["subject_id", "is_previous_een_tp", "is_next_post_een_tp"]]
    .groupby("subject_id")
    .sum()[lambda x: (x.is_previous_een_tp == 1) & (x.is_next_post_een_tp == 1)]
    .index
)
d[lambda x: x.subject_id.isin(_subject_list)].set_index(
    ["subject_id", "is_next_post_een_tp"]
)[["days_since_last_een_tp", "strain_bc_to_last_een"]].unstack("is_next_post_een_tp")

In [None]:
# Strain-level Bray-Curtis dissimilarity to each subject's last EEN sample (top) and
# species relative abundance (bottom), against days since that last EEN timepoint.
# Colour encodes subject and marker shape encodes sample type. The per-subject
# tables are collected into `all_subject_data`.
fig, axs = plt.subplots(2, figsize=(10, 5), sharex=True)

subject_list = [
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "K",
    "L",
    "M",
    "N",
    "O",
    # 'P',
    "Q",
    "R",
    "S",
    "T",
    "U",
]
subject_palette = lib.plot.construct_ordered_palette(subject_list, cm="rainbow")
sample_type_palette = {"EEN": "o", "PostEEN": "<", "PreEEN": "."}

all_subject_data = []
for subject_id in subject_list:
    d0 = meta[meta.subject_id == subject_id].assign(
        # Sample is part of the strain fit (`world`) for `species_id`.
        has_strain_info=lambda x: x.index.isin(world.sample.values),
        species_rabund=species_rabund[species_id],
        species_depth=species_depth[species_id],
        label=lambda x: x[["collection_date_relative_een_end", "sample_type"]].apply(
            tuple, axis=1
        ),
        # Date of the latest EEN sample with strain info (NaN if there is none).
        last_een_tp_days=lambda x: x[
            (x.sample_type == "EEN") & x.has_strain_info
        ].collection_date_relative_een_end.max(),
        is_last_een_tp=lambda x: x.collection_date_relative_een_end
        == x.last_een_tp_days,
        cbrt_days_since_end_een=lambda x: np.cbrt(x.collection_date_relative_een_end),
        days_since_last_een_tp=lambda x: x.collection_date_relative_een_end
        - x.last_een_tp_days,
        # The single sample just before the last EEN timepoint (largest negative offset).
        is_previous_een_tp=lambda x: x.index.isin(
            x[lambda x: x.days_since_last_een_tp < 0]
            .days_since_last_een_tp.sort_values()
            .index[-1:]
        ),
        # The single sample just after it (smallest positive offset).
        is_next_post_een_tp=lambda x: x.index.isin(
            x[lambda x: x.days_since_last_een_tp > 0]
            .days_since_last_een_tp.sort_values()
            .index[:1]
        ),
    )

    w = world.sel(sample=idxwhere(d0.has_strain_info))
    _comm = w.community.to_pandas()

    # Reference sample(s): last-EEN-date samples that are actually in `w`.
    # - A same-date sample without strain info would make .loc raise KeyError.
    # - A subject with no such sample would make cdist(...)[0] raise IndexError.
    # Such subjects are skipped.
    _last_een_samples = idxwhere(d0.is_last_een_tp & d0.has_strain_info)
    if len(_last_een_samples) == 0:
        continue

    # Dissimilarity of every strain-resolved sample to the first reference sample.
    strain_bc_to_last_een = pd.Series(
        sp.spatial.distance.cdist(
            _comm.loc[_last_een_samples],
            _comm,
            metric="braycurtis",
        )[0],
        index=w.sample.values,
    )
    d1 = d0.assign(strain_bc_to_last_een=strain_bc_to_last_een).dropna(
        subset=["days_since_last_een_tp", "strain_bc_to_last_een"]
    )
    all_subject_data.append(d1)

    axs[0].plot(
        "days_since_last_een_tp",
        "strain_bc_to_last_een",
        data=d1,
        label=subject_id,
        color=subject_palette[subject_id],
    )
    axs[1].plot(
        "days_since_last_een_tp",
        "species_rabund",
        data=d1,
        label=subject_id,
        color=subject_palette[subject_id],
    )
    # Markers by sample type; labels starting with "_" are omitted from the legend.
    for sample_type, d2 in d1.groupby("sample_type"):
        axs[0].scatter(
            "days_since_last_een_tp",
            "strain_bc_to_last_een",
            data=d2,
            marker=sample_type_palette[sample_type],
            color=subject_palette[subject_id],
            label="__nolegend__",
        )
        axs[1].scatter(
            "days_since_last_een_tp",
            "species_rabund",
            data=d2,
            marker=sample_type_palette[sample_type],
            color=subject_palette[subject_id],
            label="__nolegend__",
        )

all_subject_data = pd.concat(all_subject_data)
axs[0].set_xscale("symlog", linthresh=1, linscale=0.1)
axs[1].set_yscale("log")
axs[0].axvline(0, linestyle="--", color="k", lw=1)
axs[0].set_xlim(-300, 300)

axs[0].legend(bbox_to_anchor=(1.2, 0.1))

In [None]:
# Jaccard dissimilarity of strain presence/absence (relative abundance > 1e-5,
# across all species) between each sample and the subject's last EEN sample,
# plotted against days since that timepoint.
# Two summaries of the samples adjacent to that timepoint follow.
_meta = meta
_comp = strain_rabund > 1e-5
_metric = 'jaccard'

fig, ax = plt.subplots(figsize=(10, 3), sharex=True)

subject_list = [
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "K",
    "L",
    "M",
    "N",
    "O",
    # 'P',
    "Q",
    "R",
    "S",
    "T",
    "U",
]
subject_palette = lib.plot.construct_ordered_palette(subject_list, cm="rainbow")
sample_type_palette = {"EEN": "o", "PostEEN": "<", "PreEEN": "."}

all_subject_data = []
for subject_id in subject_list:
    d0 = _meta[_meta.subject_id == subject_id].assign(
        # Sample is part of the strain fit (`world`) for `species_id`.
        has_strain_info=lambda x: x.index.isin(world.sample.values),
        label=lambda x: x[["collection_date_relative_een_end", "sample_type"]].apply(
            tuple, axis=1
        ),
        # Date of the latest EEN sample with strain info (NaN if there is none).
        last_een_tp_days=lambda x: x[
            (x.sample_type == "EEN") & x.has_strain_info
        ].collection_date_relative_een_end.max(),
        is_last_een_tp=lambda x: x.collection_date_relative_een_end
        == x.last_een_tp_days,
        cbrt_days_since_end_een=lambda x: np.cbrt(x.collection_date_relative_een_end),
        days_since_last_een_tp=lambda x: x.collection_date_relative_een_end
        - x.last_een_tp_days,
        # The single sample just before the last EEN timepoint (largest negative offset).
        is_previous_een_tp=lambda x: x.index.isin(
            x[lambda x: x.days_since_last_een_tp < 0]
            .days_since_last_een_tp.sort_values()
            .index[-1:]
        ),
        # The single sample just after it (smallest positive offset).
        is_next_post_een_tp=lambda x: x.index.isin(
            x[lambda x: x.days_since_last_een_tp > 0]
            .days_since_last_een_tp.sort_values()
            .index[:1]
        ),
    )

    # A subject without a last-EEN sample has nothing to compare against, and
    # cdist(...)[0] would raise IndexError on the empty result; skip it.
    _last_een_samples = idxwhere(d0.is_last_een_tp)
    if len(_last_een_samples) == 0:
        continue

    diss_to_last_een = pd.Series(
        sp.spatial.distance.cdist(
            _comp.loc[_last_een_samples],
            _comp.loc[d0.index],
            metric=_metric,
        )[0],
        index=d0.index,
    )
    d1 = d0.assign(diss_to_last_een=diss_to_last_een).dropna(
        subset=["days_since_last_een_tp", "diss_to_last_een"]
    )
    all_subject_data.append(d1)

    ax.plot(
        "days_since_last_een_tp",
        "diss_to_last_een",
        data=d1,
        label=subject_id,
        color=subject_palette[subject_id],
    )
    # Markers by sample type; labels starting with "_" are omitted from the legend.
    for sample_type, d2 in d1.groupby("sample_type"):
        ax.scatter(
            "days_since_last_een_tp",
            "diss_to_last_een",
            data=d2,
            marker=sample_type_palette[sample_type],
            color=subject_palette[subject_id],
            label="__nolegend__",
        )

all_subject_data = pd.concat(all_subject_data)
ax.set_xscale("symlog", linthresh=1, linscale=0.1)
ax.axvline(0, linestyle="--", color="k", lw=1)
ax.set_xlim(-300, 300)
ax.legend(bbox_to_anchor=(1.0, 1.0))

# Subjects with exactly one sample just before and one just after the last EEN
# timepoint: dissimilarity vs. days since that timepoint.
fig, ax = plt.subplots()
d3 = all_subject_data[lambda x: x.is_previous_een_tp | x.is_next_post_een_tp]
_subject_list = (
    d3[["subject_id", "is_previous_een_tp", "is_next_post_een_tp"]]
    .groupby("subject_id")
    .sum()[lambda x: (x.is_previous_een_tp == 1) & (x.is_next_post_een_tp == 1)]
    .index
)
d4 = d3.loc[lambda x: x.subject_id.isin(_subject_list)]
sns.regplot(x='days_since_last_een_tp', y='diss_to_last_een', data=d4, ax=ax)


# Dissimilarity of all adjacent samples, split by whether they follow (True) or
# precede (False) the last EEN timepoint.
fig, ax = plt.subplots()
sns.stripplot(x='is_next_post_een_tp', y='diss_to_last_een', data=d3, ax=ax)

In [None]:
# Display `d`; its contents depend on whichever cell assigned it most recently.
d

In [None]:
# Clustered heatmap of the number of strains (relative abundance > 1e-5) shared by
# each pair of stool and microcosm samples, with subject and microcosm colour bars.
_meta = pd.concat([meta, micro_meta.assign(is_micro=True)])
_comp = strain_rabund.loc[_meta.index] > 1e-5

# With 0/1 presence indicators, the shared count for samples (a, b) is the dot
# product of their rows, so the full matrix is P @ P.T. This replaces an O(n^2)
# Python double loop and still yields float64 counts.
_presence = _comp.to_numpy(dtype=float)
num_shared_taxa = pd.DataFrame(
    _presence @ _presence.T, index=_comp.index, columns=_comp.index
)


subject_palette = lib.plot.construct_ordered_palette(_meta.subject_id.unique(), cm='tab20')

_colors = pd.DataFrame(dict(
    subj=_meta.subject_id.map(subject_palette),
    # Microcosm rows carry is_micro=True; stool rows are NaN after the concat.
    is_micro=_meta.is_micro.fillna(False).map({True: 'black', False: 'grey'}),
))
# gamma = 1/3 < 1 stretches the low end of the colour scale.
sns.clustermap(num_shared_taxa, col_colors=_colors, row_colors=_colors, norm=mpl.colors.PowerNorm(1/3))

In [None]:
# Legend-only figure mapping subject ids to their palette colours (empty scatters
# serve as legend handles).
for subject_id in subject_palette:
    plt.scatter([], [], c=subject_palette[subject_id], label=subject_id)
plt.legend(ncols=3)

# Stool metadata for subjects B and C.
meta[lambda x: x.subject_id.isin(['B', 'C'])]

In [None]:
# Clustered heatmap of shared strain counts (relative abundance > 1e-5) among the
# stool and microcosm samples of subjects B and C.
_meta = pd.concat([meta, micro_meta.assign(is_micro=True)])[lambda x: x.subject_id.isin(['B', 'C'])]
_comp = strain_rabund.loc[_meta.index] > 1e-5

# Shared count for (a, b) is the dot product of their 0/1 presence rows: P @ P.T,
# in place of an O(n^2) Python double loop (float64 counts).
_presence = _comp.to_numpy(dtype=float)
num_shared_taxa = pd.DataFrame(
    _presence @ _presence.T, index=_comp.index, columns=_comp.index
)

_colors = pd.DataFrame(dict(
    subj=_meta.subject_id.map(subject_palette),
    # Microcosm rows carry is_micro=True; stool rows are NaN after the concat.
    is_micro=_meta.is_micro.fillna(False).map({True: 'black', False: 'grey'}),
))
sns.clustermap(num_shared_taxa, col_colors=_colors, row_colors=_colors, norm=mpl.colors.PowerNorm(1/2))

In [None]:
# Clustered Jaccard dissimilarity of strain presence/absence (> 1e-5) among the
# stool and microcosm ("passaged") samples of subjects B and C.
_meta = pd.concat([meta, micro_meta.assign(is_passaged=True)])[lambda x: x.subject_id.isin(['B', 'C'])]
_comp = strain_rabund.loc[_meta.index] > 1e-5
_metric = 'jaccard'

# Condensed pdist output expanded to a square, sample-labelled matrix.
diss = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_comp, metric=_metric)), index=_comp.index, columns=_comp.index)

_colors = pd.DataFrame(dict(
    subject=_meta.subject_id.map(subject_palette),
    is_passaged=_meta.is_passaged.fillna(False).map({True: 'black', False: 'grey'}),
))
# PowerNorm with gamma 1 is a linear colour scale.
sns.clustermap(diss, col_colors=_colors, row_colors=_colors, norm=mpl.colors.PowerNorm(1/1))

In [None]:
# Are samples CF_11 and CF_15 swapped? (See the FIXME in the stool-metadata cell.)

In [None]:
# Clustered heatmap of shared species counts (GT-Pro relative abundance > 1e-5) among
# the stool and microcosm samples of subjects B and C.
_meta = pd.concat([meta, micro_meta.assign(is_micro=True)])[lambda x: x.subject_id.isin(['B', 'C'])]
_comp = species_rabund.loc[_meta.index] > 1e-5

# Shared count for (a, b) is the dot product of their 0/1 presence rows: P @ P.T,
# in place of an O(n^2) Python double loop (float64 counts).
_presence = _comp.to_numpy(dtype=float)
num_shared_taxa = pd.DataFrame(
    _presence @ _presence.T, index=_comp.index, columns=_comp.index
)

_colors = pd.DataFrame(dict(
    subj=_meta.subject_id.map(subject_palette),
    # Microcosm rows carry is_micro=True; stool rows are NaN after the concat.
    is_micro=_meta.is_micro.fillna(False).map({True: 'black', False: 'grey'}),
))
sns.clustermap(num_shared_taxa, col_colors=_colors, row_colors=_colors, norm=mpl.colors.PowerNorm(1/3))

In [None]:
# Clustered Jaccard dissimilarity of strain presence/absence (> 1e-5) across all
# stool and microcosm samples, with subject and microcosm colour bars.
_meta = pd.concat([meta, micro_meta.assign(is_micro=True)])
_comp = strain_rabund.loc[_meta.index] > 1e-5
_metric = 'jaccard'

subject_palette = lib.plot.construct_ordered_palette(_meta.subject_id.unique(), cm='tab20')

_colors = pd.DataFrame(dict(
    subj=_meta.subject_id.map(subject_palette),
    is_micro=_meta.is_micro.fillna(False).map({True: 'black', False: 'grey'}),
))
# Condensed pdist output expanded to a square, sample-labelled matrix.
d = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(_comp, metric=_metric)), index=_comp.index, columns=_comp.index)
sns.clustermap(d, col_colors=_colors, row_colors=_colors)

In [None]:
# Compare dissimilarity to the last EEN sample between the timepoint just before it
# (True) and the first timepoint after it (False).
# d4 keeps only subjects with exactly one row of each kind, so after unstacking each
# row of d5 is one subject and the two columns are paired measurements. Hence the
# paired t-test rather than the independent-samples test.
d5 = d4.set_index(['subject_id', 'is_previous_een_tp']).diss_to_last_een.unstack()
sp.stats.ttest_rel(d5[False], d5[True])

In [None]:
# Pairwise strain-level Bray-Curtis dissimilarity (strain fit for `species_id`)
# between samples of the same subject. Each pair is annotated with the days between
# the two samples and the EEN-phase transition it spans. One scatter panel is drawn
# per transition type.
x = world.community.to_pandas()
# Stool samples in the strain fit, in `meta` order. Index.intersection keeps the
# caller's order; iterating a Python set of strings can reorder between sessions.
sample_list = list(meta.index.intersection(x.index))
m = meta.loc[sample_list]

subject_palette = lib.plot.construct_ordered_palette(m.subject_id.unique())

d0 = []

for subject_id, m1 in m.groupby("subject_id"):
    d0.append(
        pd.DataFrame(
            dict(
                subject_id=subject_id,
                bc_diss=sp.spatial.distance.pdist(x.loc[m1.index], metric="braycurtis"),
                # For one-dimensional points, cityblock distance is |t_i - t_j|.
                delta_time=sp.spatial.distance.pdist(
                    m1.collection_date_relative_een_end.to_frame(),
                    metric="cityblock",
                ),
                # Each sample type gets a distinct prime, so the product of a pair's
                # codes identifies the unordered pair of phases.
                transition_type=(
                    pd.Series(
                        sp.spatial.distance.pdist(
                            m1.sample_type.map(
                                {"PreEEN": 2, "EEN": 3, "PostEEN": 5}
                            ).to_frame(),
                            # Return a scalar (pdist stores one float per pair);
                            # returning a length-1 array is deprecated in NumPy.
                            metric=lambda u, v: u[0] * v[0],
                        )
                    )
                    .astype(int)
                    .map(
                        {
                            2 * 2: "pre",
                            2 * 3: "pre->during",
                            2 * 5: "pre->post",
                            3 * 3: "during",
                            3 * 5: "during->post",
                            5 * 5: "post->post",
                        }
                    )
                ),
            )
        )
    )

d0 = pd.concat(d0)

num_transition_types = len(d0.transition_type.unique())
fig, axs = lib.plot.subplots_grid(
    ncols=2, naxes=num_transition_types, sharex=True, sharey=True
)
for (transition_type, d1), ax in zip(d0.groupby("transition_type"), axs.flatten()):
    ax.scatter(
        x="delta_time",
        y="bc_diss",
        data=d1.assign(subject_color=lambda x: x.subject_id.map(subject_palette)),
        c="subject_color",
    )
    ax.set_title(transition_type)

ax.set_xscale("log")
ax.set_xlim(1)

In [None]:
# Per-subject mean Bray-Curtis dissimilarity for each transition type.
d1 = (
    d0.groupby(["subject_id", "transition_type"])[["bc_diss", "delta_time"]]
    .mean()
    .reset_index()
)

sns.stripplot(x="transition_type", hue="subject_id", y="bc_diss", data=d1)
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Per-subject mean time between samples for each transition type.
d1 = (
    d0.groupby(["subject_id", "transition_type"])[["bc_diss", "delta_time"]]
    .mean()
    .reset_index()
)

sns.stripplot(x="transition_type", hue="subject_id", y="delta_time", data=d1)
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
import pymc as pm
import pytensor as pt
from patsy import dmatrices

# PyMC linear model of logit-transformed pairwise strain Bray-Curtis dissimilarity.
# Predictors:
# - sqrt(days between the two samples);
# - subject (sum-to-zero contrasts);
# - EEN transition type (treatment contrasts, reference "during").
# Uses `m`, `x` and `subject_palette` as last assigned by earlier cells.

# Pads Bray-Curtis away from 0 and 1, where the logit is infinite:
# (bc + 0.1) / 1.2 maps [0, 1] onto [1/12, 11/12].
bc_padding = 0.1

d0 = []
for subject_id, m1 in m.groupby("subject_id"):
    d0.append(
        pd.DataFrame(
            dict(
                subject_id=subject_id,
                bc_diss=sp.spatial.distance.pdist(x.loc[m1.index], metric="braycurtis"),
                # For one-dimensional points, cityblock distance is |t_i - t_j|.
                delta_time=sp.spatial.distance.pdist(
                    m1.collection_date_relative_een_end.to_frame(),
                    metric="cityblock",
                ),
                # Distinct primes per sample type: the product identifies the
                # unordered pair of phases.
                transition_type=(
                    pd.Series(
                        sp.spatial.distance.pdist(
                            m1.sample_type.map(
                                {"PreEEN": 2, "EEN": 3, "PostEEN": 5}
                            ).to_frame(),
                            # Scalar return; a length-1 array is deprecated in NumPy.
                            metric=lambda u, v: u[0] * v[0],
                        )
                    )
                    .astype(int)
                    .map(
                        {
                            2 * 2: "pre",
                            2 * 3: "pre->during",
                            2 * 5: "pre->post",
                            3 * 3: "during",
                            3 * 5: "during->post",
                            5 * 5: "post->post",
                        }
                    )
                ),
            )
        )
    )
d0 = pd.concat(d0)

d1 = d0.assign(
    # Not used by the formula below (it uses sqrt(delta_time)); -inf for same-day pairs.
    trnsfm_delta_time=lambda x: x.delta_time.pipe(lambda x: np.log(x)),
    trnsfm_bc_diss=lambda x: x.bc_diss.pipe(
        lambda x: sp.special.logit((x + bc_padding) / (1 + 2 * bc_padding))
    ),
    subject_color=lambda x: x.subject_id.map(subject_palette),
)
y, X = dmatrices(
    'trnsfm_bc_diss ~ np.sqrt(delta_time) + C(subject_id, Sum) + C(transition_type, Treatment("during"))',
    data=d1,
    return_type="dataframe",
)

with pm.Model() as model0:
    _x = pm.Data(
        "_x",
        X,
        dims=("pair", "covariate"),
        coords=dict(pair=X.index, covariate=X.columns),
    )
    # dims as real tuples: ("pair") without a comma is just the string "pair".
    _y = pm.Data("_y", y.squeeze(), dims=("pair",), coords=dict(pair=y.index))

    sigma = pm.HalfCauchy("sigma", beta=10)
    beta = pm.Normal("beta", dims=("covariate",), sigma=10)
    obs = pm.Normal("obs", mu=_x @ beta, sigma=sigma, observed=_y, dims=("pair",))

    trace0 = pm.sample()

In [None]:
# Posterior summary table for the model fit in the cell above.
pm.summary(trace0)

In [None]:
# Same pairwise analysis at the species level: Bray-Curtis dissimilarity of GT-Pro
# species relative abundance between samples of the same subject, by transition type.
x = species_rabund
# Stool samples present in the species table, in `meta` order. Index.intersection
# keeps the caller's order; iterating a Python set of strings can reorder between
# sessions.
sample_list = list(meta.index.intersection(x.index))
m = meta.loc[sample_list]

subject_palette = lib.plot.construct_ordered_palette(m.subject_id.unique())

d0 = []

for subject_id, m1 in m.groupby("subject_id"):
    d0.append(
        pd.DataFrame(
            dict(
                subject_id=subject_id,
                bc_diss=sp.spatial.distance.pdist(x.loc[m1.index], metric="braycurtis"),
                # For one-dimensional points, cityblock distance is |t_i - t_j|.
                delta_time=sp.spatial.distance.pdist(
                    m1.collection_date_relative_een_end.to_frame(),
                    metric="cityblock",
                ),
                # Distinct primes per sample type: the product identifies the
                # unordered pair of phases.
                transition_type=(
                    pd.Series(
                        sp.spatial.distance.pdist(
                            m1.sample_type.map(
                                {"PreEEN": 2, "EEN": 3, "PostEEN": 5}
                            ).to_frame(),
                            # Scalar return; a length-1 array is deprecated in NumPy.
                            metric=lambda u, v: u[0] * v[0],
                        )
                    )
                    .astype(int)
                    .map(
                        {
                            2 * 2: "pre",
                            2 * 3: "pre->during",
                            2 * 5: "pre->post",
                            3 * 3: "during",
                            3 * 5: "during->post",
                            5 * 5: "post->post",
                        }
                    )
                ),
            )
        )
    )

d0 = pd.concat(d0)

num_transition_types = len(d0.transition_type.unique())
fig, axs = lib.plot.subplots_grid(
    ncols=2, naxes=num_transition_types, sharex=True, sharey=True
)
for (transition_type, d1), ax in zip(d0.groupby("transition_type"), axs.flatten()):
    ax.scatter(
        x="delta_time",
        y="bc_diss",
        data=d1.assign(subject_color=lambda x: x.subject_id.map(subject_palette)),
        c="subject_color",
    )
    ax.set_title(transition_type)

ax.set_xscale("log")
ax.set_xlim(1)

In [None]:
# TODO: unfinished cell. The loop header below was never completed and is a
# SyntaxError as written; it is commented out so this cell no longer fails.
# for species_id in

In [None]:
# Bray-Curtis dissimilarity of each strain-resolved sample in `w` to the first sample
# flagged is_last_een_tp in `d1`.
# NOTE(review): relies on `w` and `d1` left over from other cells. In file order, `d1`
# was last bound by the species-level transition loop (a groupby chunk with no
# is_last_een_tp column), so this fails unless run right after the Bray-Curtis
# trajectory loop body.
strain_bc_to_last_een = pd.Series(
    sp.spatial.distance.cdist(
        w.community.to_pandas().loc[idxwhere(d1.is_last_een_tp)],
        w.community.to_pandas(),
        metric="braycurtis",
    )[0],
    index=w.sample.values,
)

In [None]:
# Genotypes of strains kept at the 0.05 abundance cutoff, concatenated along the
# strain dimension with genotypes estimated from each sample's metagenotype
# (presumably per-sample consensus calls — confirm), on the subsampled positions.
sf.plot.plot_genotype(
    sf.Genotype.concat(
        dict(
            strain=world.drop_low_abundance_strains(0.05).genotype,
            mgen=world.metagenotype.to_estimated_genotype(),
        ),
        dim="strain",
    ).sel(position=position_ss),
    transpose=True,
)

In [None]:
# Sample inspected in the cells below.
sample = "CF_060"

In [None]:
# Mean metagenotype depth of that sample.
world.metagenotype.mean_depth().sel(sample=[sample])

In [None]:
# Its four most abundant strains.
world.community.sel(sample=[sample]).to_series().sort_values(ascending=False).head(4)

In [None]:
# Metagenotype frequency spectrum of the sample (log-scaled counts).
sf.plot.plot_metagenotype_frequency_spectrum(world, sample, bins=101)
plt.yscale("log")

In [None]:
# Side-by-side frequency spectra of CF_059 and CF_060.
sf.plot.plot_metagenotype_frequency_spectrum_compare_samples(
    world, sample_list=["CF_059", "CF_060"]
)

In [None]:
# Genotype frequency spectrum of strain 3 (log-scaled counts).
sf.plot.plot_genotype_frequency_spectrum(world, strain=3, bins=100)
plt.yscale("log")

In [None]:
# Genotype entropy (norm=2) of each strain on the subsampled positions, shown as a
# dedicated entropy plot and as row colours on the genotype heatmap.
sf.plot.plot_genotype_entropy(
    world.sel(position=position_ss),
    row_colors_func=lambda w: w.genotype.entropy(norm=2),
)
sf.plot.plot_genotype(
    world.sel(position=position_ss),
    row_colors_func=lambda w: w.genotype.entropy(norm=2),
)

In [None]:
# Histogram of per-strain genotype entropy (the sort does not affect the histogram).
plt.hist(world.genotype.entropy(norm=2).to_series().sort_values(ascending=False))

In [None]:
# Strains with the highest genotype entropy.
world.genotype.entropy(norm=2).to_series().sort_values(ascending=False).head()

In [None]:
# Strains with genotype entropy above 0.25.
high_entropy_genotype_list = idxwhere(world.genotype.entropy(norm=2).to_series() > 0.25)
high_entropy_genotype_list

In [None]:
# Strains whose fraction never reaches 0.05 in any sample.
low_representation_strain_list = idxwhere(
    world.community.max("sample").to_series() < 0.05
)
low_representation_strain_list

In [None]:
# Union of the two selections (strains to be replaced below). `+` assumes idxwhere
# returns lists (list concatenation) — confirm.
replace_strain_list = list(
    set(low_representation_strain_list + high_entropy_genotype_list)
)
len(replace_strain_list)

In [None]:
# Build modified genotype/community tables with the selected strains reset —
# presumably as initial values for a refit (inferred from the names; confirm).
# Genotypes of the replaced strains are set to 0.5 at every position.
geno_init = world.genotype.data.to_pandas().copy()
geno_init.loc[replace_strain_list] = 0.5
geno_init = sf.data.Genotype(geno_init.stack().to_xarray())
sf.plot.plot_genotype(
    geno_init.sel(position=position_ss),
    row_linkage_func=lambda w: world.sel(position=position_ss).genotype.linkage(),
)

In [None]:
# In each sample, pool the replaced strains' combined fraction and split it evenly
# among them, then wrap as an sfacts Community with .fuzzed(eps=1e-3) applied.
comm_init = world.community.data.to_pandas().copy()
comm_init[replace_strain_list] = pd.DataFrame(
    {
        s: (comm_init[replace_strain_list].sum(1) / len(replace_strain_list))
        for s in replace_strain_list
    }
)
comm_init = sf.data.Community(comm_init.stack().to_xarray()).fuzzed(eps=1e-3)
sf.plot.plot_community(
    comm_init,
    row_linkage_func=lambda w: world.sel(position=position_ss).genotype.linkage(),
)