# Preamble

In [None]:
# Load IPython's autoreload extension (no `%autoreload` mode is selected in this cell).
%load_ext autoreload

In [None]:
import os as _os

# Run the rest of the notebook from the project root named by $PROJECT_ROOT,
# so the relative meta/ and data/ paths below resolve; the last expression
# echoes the resolved working directory as the cell output.
_os.chdir(_os.environ["PROJECT_ROOT"])
_os.path.realpath(_os.curdir)

## Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform, cdist
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Return *labels* reordered to follow the leaf order of a linkage tree.

    Parameters
    ----------
    linkage : numpy.ndarray
        A SciPy hierarchical-clustering linkage matrix (as returned by
        ``scipy.cluster.hierarchy.linkage``).
    labels : sequence
        One label per original observation, indexed by observation id.
        Any sequence (list, tuple, ndarray, ...) is accepted; it is converted
        to a NumPy array so it can be fancy-indexed.

    Returns
    -------
    numpy.ndarray
        The labels in the left-to-right (pre-order) leaf order of the tree.
    """
    # pre_order() with an id callback lists the leaf observation ids left to right.
    leaf_ids = sp.cluster.hierarchy.to_tree(linkage).pre_order(lambda node: node.id)
    # np.asarray lets plain lists/tuples work too (they cannot be indexed by a list).
    return np.asarray(labels)[leaf_ids]

def label_een_experiment_sample(x):
    """Build a display label for one row of the sample metadata table.

    Parameters
    ----------
    x : pandas.Series
        One metadata row (e.g. via ``DataFrame.apply(..., axis=1)``); ``x.name``
        is the sample id. ``sample_type`` selects which other fields are used.

    Returns
    -------
    str
        ``"[<sample id>] ..."`` followed by type-specific fields:

        * human: collection day relative to EEN end, diet/media
        * Fermenter_inoculum: source samples, ``inoc``, diet/media
        * Fermenter: source samples, diet/media
        * mouse: source samples, diet/media, ``inflam`` / ``not_inf``

    Raises
    ------
    ValueError
        If ``sample_type`` is not one of the four types above, or a mouse
        sample has a ``status_mouse_inflamed`` other than ``"Inflamed"`` /
        ``"not_Inflamed"``.
    """
    prefix = f"[{x.name}]"
    if x.sample_type == "human":
        return f"{prefix} {x.collection_date_relative_een_end} {x.diet_or_media}"
    if x.sample_type == "Fermenter_inoculum":
        return f"{prefix} {x.source_samples} inoc {x.diet_or_media}"
    if x.sample_type == "Fermenter":
        return f"{prefix} {x.source_samples} {x.diet_or_media}"
    if x.sample_type == "mouse":
        if x.status_mouse_inflamed == "Inflamed":
            return f"{prefix} {x.source_samples} {x.diet_or_media} inflam"
        if x.status_mouse_inflamed == "not_Inflamed":
            return f"{prefix} {x.source_samples} {x.diet_or_media} not_inf"
        # Report the field that is actually wrong (not the sample type).
        raise ValueError(
            f"mouse inflammation status {x.status_mouse_inflamed} not understood"
        )
    raise ValueError(f"sample type {x.sample_type} not understood")

In [None]:
# Inline figures are rendered at 50 dpi, with seaborn's "talk" context
# (presentation-sized fonts and line widths).
plt.rcParams.update({"figure.dpi": 50})
sns.set_context(context="talk")

# Prepare Metadata

In [None]:
# Display order of study subjects (A, B and H are listed first).
subject_order = "A B H C D E F G K L M N O P Q R S T U".split()

In [None]:
# Sample metadata, indexed by sample_id, with two derived columns:
#   label      -- (collection_date_relative_een_end, diet_or_media, sample_id) tuple
#   full_label -- display string from label_een_experiment_sample
sample = pd.read_table("meta/een-mgen/sample.tsv")
sample = sample.assign(
    label=sample[
        ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
    ].apply(tuple, axis=1)
)
sample = sample.set_index("sample_id")
sample = sample.assign(full_label=sample.apply(label_een_experiment_sample, axis=1))

# Prepare Data

In [None]:
# Species depth table (long format: sample, species_id, depth) pivoted into a
# samples x species matrix; missing sample/species combinations become 0.
motu_depth = pd.read_table(
    "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
    names=["sample", "species_id", "depth"],
    index_col=["sample", "species_id"],
)["depth"].unstack(fill_value=0)
# Sample names become "CF_<int>"; species ids become strings.
motu_depth = motu_depth.rename(
    index=lambda name: f"CF_{int(name.split('_')[1])}",
    columns=str,
)
# Sample swap: exchange the CF_11 and CF_15 labels.
motu_depth = motu_depth.rename(index={"CF_15": "CF_11", "CF_11": "CF_15"})
# Relative abundance: each sample's depths divided by that sample's total depth.
motu_rabund = motu_depth.div(motu_depth.sum(axis=1), axis=0)

motu_rabund

# Species Enrichment Analysis

# Strain Time-series

In [None]:
# Focal species for the strain time-series analysis.
motu_id = "100099"

# Plot settings.
bar_width = 1.0  # width of each stacked strain bar
ylinthresh = 1e-4  # linear-region threshold for symlog relative-abundance axes
drop_strains_thresh = 0.5  # threshold passed to drop_low_abundance_strains per subject

In [None]:
# Load the sfacts strain-deconvolution fit for the focal species (`motu_id`).
sf_fit = (
    sf.data.World.load(
        f"data/group/een/species/sp-{motu_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts48-s85-seed0.world.nc"
    )
    # Normalise sample names to "CF_<int>" (int() strips any zero-padding).
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    # Swap CF_11 and CF_15, mirroring the sample-swap correction applied to motu_depth.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    .drop_low_abundance_strains(0.01)
    # String strain labels, so they can be matched against "-1" and palette keys.
    .rename_coords(strain=str)
)
print(sf_fit.sizes)
# Fit diagnostics (element [1] of each sfacts evaluation result) and community
# entropy; presumably one value per sample -- they are used below as heatmap
# column colour annotations.
mgtp_error = sf.evaluation.metagenotype_error2(sf_fit, discretized=False)[1]
entrp_error = sf.evaluation.metagenotype_entropy_error(
    sf_fit, discretized=False, p=1, montecarlo_draws=10
)[1]
comm_entrp = sf_fit.community.entropy().to_series()
# high_mgtp_error = mgtp_error >= 0.1
# high_entrp_error = entrp_error >= 0.2
# high_comm_entrp = comm_entrp >= 1.5

# Random subset of at most 2000 positions, to limit the metagenotype heatmap size.
position_ss = sf_fit.random_sample(position=min(2000, sf_fit.sizes['position'])).position

# Genotype similarity ordered palette:
# strains are clustered by genotype, samples by metagenotype.
strain_linkage = sf_fit.genotype.linkage(optimal_ordering=True)
sample_linkage = sf_fit.metagenotype.linkage(optimal_ordering=True)
strain_order = list(
    linkage_order(
        strain_linkage,
        sf_fit.strain.values,
    )
)
strain_order.remove("-1")  # Drop "other" strain.
# Colours assigned along the genotype-similarity order.
strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
)

sample_colors = xr.Dataset(dict(mg_err=mgtp_error, en_err=entrp_error, cen=comm_entrp))

# Metagenotype heatmap on the subsampled positions; samples ordered by sample_linkage.
sf.plot.plot_metagenotype(
    sf_fit.sel(position=position_ss),
    scalex=0.4,
    col_linkage_func=lambda w: sample_linkage,
    col_colors_func=lambda w: sample_colors,
)

# Strain-fraction heatmap: strains ordered by genotype linkage, samples by
# metagenotype linkage.
# NOTE(review): "-1" was removed from strain_order before building the palette,
# so its row colour depends on how construct_ordered_palette handles unknown
# keys -- confirm.
sf.plot.plot_community(
    sf_fit,
    scalex=0.4,
    scaley=0.6,
    row_linkage_func=lambda w: strain_linkage,
    row_colors=sf_fit.strain.to_series().map(strain_palette),
    col_linkage_func=lambda w: sample_linkage,
    col_colors_func=lambda w: sample_colors,
)

In [None]:
# Focal-species relative abundance (points, symlog left axis) over strain
# composition (stacked bars, right axis) for subjects A, B and H: one row of
# panels per subject, one column per sample type.
d0 = (
    sample.loc[
        lambda x: (
            x.sample_type.isin(["human", "Fermenter", "mouse"])
            & x.subject_id.isin(["A", "B", "H"])
        ),
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
            "full_label",
        ],
    ]
    .sort_values(
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
        ]
    )
    .assign(rabund=motu_rabund[motu_id])
    # Keep only samples with a relative-abundance value for this species.
    .dropna(subset=["rabund"])
)

sample_type_order = ["human", "Fermenter", "mouse"]
# Deliberately not named `subject_order`: that global holds the full study
# subject list, and re-binding it here would change every later cell.
een_subject_order = ["A", "B", "H"]

# Samples per (subject, sample type) panel, rows/columns in plotting order;
# combinations without samples count as 0.
_grid_sample_counts = (
    d0[["subject_id", "sample_type"]]
    .value_counts()
    .unstack(fill_value=0)
    .reindex(index=een_subject_order, columns=sample_type_order, fill_value=0)
)
fig, axs = plt.subplots(
    len(een_subject_order),
    len(sample_type_order),
    figsize=(20, 15),
    # Column widths follow the busiest panel of each sample type (at least 1,
    # so an empty sample type cannot give a zero/NaN width).
    width_ratios=_grid_sample_counts.max().clip(lower=1).to_numpy(),
    sharey=True,
    squeeze=False,  # always a 2-D array of Axes, so rows zip cleanly
    gridspec_kw=dict(wspace=0.1, hspace=3),
)

for subject_id, ax_row in zip(een_subject_order, axs):
    d1 = d0[d0.subject_id == subject_id]
    # This subject's samples that also appear in the strain fit.
    strain_frac_sample_list = list(set(d1.index) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Genotype-similarity order of the retained strains, "-1" stacked last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()
    d2 = d1.join(comm)

    for sample_type, ax in zip(sample_type_order, ax_row):
        d3 = d2[d2.sample_type == sample_type].assign(
            xpos=lambda x: np.arange(len(x.index))
        )
        ax.scatter("xpos", "rabund", data=d3, color="k", s=10, label="__nolegend__")
        ax.set_ylim(-1e-3, 1)
        ax.set_yscale("symlog", linthresh=ylinthresh, linscale=0.1)
        ax.set_xlim(-0.5, _grid_sample_counts[sample_type].max())
        ax.set_xticks(d3.xpos)
        ax.set_xticklabels(d3.full_label)

        # Stacked strain fractions on a twin axis.
        ax1 = ax.twinx()
        top_last = 0
        for strain in _strain_order:
            ax1.bar(
                x="xpos",
                height=strain,
                data=d3,
                bottom=top_last,
                width=bar_width,
                alpha=1.0,
                color=strain_palette[strain],
                edgecolor="k",
                lw=1,
                label="__nolegend__",
            )
            top_last = top_last + d3[strain]
            # Empty proxy artist so the strain gets a legend entry on `ax`.
            ax.scatter([], [], color=strain_palette[strain], label=strain, marker="s", s=80)
        ax1.set_yticks([])
        # Draw the points above the bars: raise `ax` and hide its background.
        ax.set_zorder(ax1.get_zorder() + 1)
        ax.patch.set_visible(False)
        ax1.patch.set_visible(True)

        ax1.set_ylim(0, 1)
        lib.plot.rotate_xticklabels(ax=ax)
        if sample_type == "mouse":
            ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Single-panel view (one subject, one sample type): focal-species relative
# abundance (points) over strain composition (stacked bars), keeping every
# strain of the fit (no per-subject low-abundance filter).
d0 = (
    sample.loc[
        lambda x: (
            x.sample_type.isin(["human", "Fermenter", "mouse"])
            & x.subject_id.isin(["A", "B", "H"])
        ),
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
            "full_label",
        ],
    ]
    .sort_values(
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
        ]
    )
    .assign(rabund=motu_rabund[motu_id])
    .dropna(subset=["rabund"])
)

subject_id = "H"
sample_type = "human"

# Right x-limit: the largest number of `sample_type` samples held by any one
# subject in d0 (the same extent the grid figure uses for this column),
# computed here so this cell does not depend on another cell's variables.
_type_counts = d0.loc[d0.sample_type == sample_type, "subject_id"].value_counts()
x_max = _type_counts.max() if len(_type_counts) else 0

fig, ax = plt.subplots()

d1 = d0[d0.subject_id == subject_id]
strain_frac_sample_list = list(set(d1.index) & set(sf_fit.sample.values))
if not strain_frac_sample_list:
    print(f"No strain analysis for {subject_id}.")
    comm = []
    _strain_order = []
else:
    w = sf_fit.sel(sample=strain_frac_sample_list).rename_coords(strain=str)
    # Genotype-similarity order of the strains, "-1" stacked last.
    _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
    comm = w.community.to_pandas()
d2 = d1.join(comm)

d3 = d2[d2.sample_type == sample_type].assign(xpos=lambda x: np.arange(len(x.index)))
ax.scatter("xpos", "rabund", data=d3, color="k", s=10, label="__nolegend__")
ax.set_ylim(-1e-3, 1)
ax.set_yscale("symlog", linthresh=ylinthresh, linscale=0.1)
ax.set_xlim(-0.5, x_max)
ax.set_xticks(d3.xpos)
ax.set_xticklabels(d3.full_label)

# Stacked strain fractions on a twin axis.
ax1 = ax.twinx()
top_last = 0
for strain in _strain_order:
    ax1.bar(
        x="xpos",
        height=strain,
        data=d3,
        bottom=top_last,
        width=bar_width,
        alpha=1.0,
        color=strain_palette[strain],
        edgecolor="k",
        lw=1,
        label="__nolegend__",
    )
    top_last = top_last + d3[strain]
    # Empty proxy artist so the strain gets a legend entry on `ax`.
    ax.scatter([], [], color=strain_palette[strain], label=strain, marker="s", s=80)
ax1.set_yticks([])
# Draw the points above the bars: raise `ax` and hide its background.
ax.set_zorder(ax1.get_zorder() + 1)
ax.patch.set_visible(False)
ax1.patch.set_visible(True)

ax1.set_ylim(0, 1)
lib.plot.rotate_xticklabels(ax=ax)
# Show the strain legend built from the proxy artists above.
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
def pair_classifier(sample_typeA, sample_typeB):
    """Return an order-independent label for a pair of sample types.

    Identical types collapse to that single name (``"EEN"``); two different
    types are sorted and joined with ``":"`` (``"EEN:PostEEN"``).
    """
    distinct_types = {sample_typeA, sample_typeB}
    return ":".join(sorted(distinct_types))


def construct_turnover_analysis_data(
    dmat,
    meta,
    sample_type_var,
    stratum_var=None,
    time_var=None,
):
    """Tabulate within-stratum pairwise dissimilarities for turnover models.

    Parameters
    ----------
    dmat : pandas.DataFrame
        Square dissimilarity matrix with sample ids as index and columns.
    meta : pandas.DataFrame
        Sample metadata indexed by sample id. It is reindexed to
        ``dmat.index`` and samples missing any requested column are dropped.
    sample_type_var : str
        Column of *meta* used to classify each pair via ``pair_classifier``.
    stratum_var : str, optional
        Column of *meta* defining strata; only pairs within one stratum are
        kept. If None, all samples form a single stratum and the
        ``stratumA``/``stratumB``/``stratum`` columns hold None.
    time_var : str, optional
        Numeric column of *meta* used for ``time_delta``. If None, the time
        columns and ``time_delta`` are NaN.

    Returns
    -------
    pandas.DataFrame
        One row per unordered sample pair, oriented so ``sampleA < sampleB``,
        with columns sampleA, sampleB, stratumA, stratumB, timeA, timeB,
        sample_typeA, sample_typeB, diss, pair_type, time_delta, stratum.
        The index is the pair's position ``i * n + j`` among all n x n
        ordered pairs of the retained samples.
    """
    # dict.fromkeys de-duplicates (e.g. stratum_var == sample_type_var) while
    # keeping order, so every meta[var] below is a Series, not a DataFrame.
    var_list = list(
        dict.fromkeys(
            var for var in (sample_type_var, stratum_var, time_var) if var is not None
        )
    )
    meta = meta.reindex(dmat.index)[var_list].dropna()
    dmat = dmat.loc[meta.index, meta.index]

    n = len(meta.index)
    # All ordered pairs, row-major (i outer, j inner) -- the same enumeration
    # as itertools.product(range(n), repeat=2).
    pos_a = np.repeat(np.arange(n), n)
    pos_b = np.tile(np.arange(n), n)

    sample_ids = meta.index.to_numpy()
    sample_types = meta[sample_type_var].to_numpy()
    if stratum_var is None:
        strata = np.full(n, None, dtype=object)
    else:
        strata = meta[stratum_var].to_numpy()
    if time_var is None:
        times = np.full(n, np.nan)
    else:
        times = meta[time_var].to_numpy()

    data = pd.DataFrame(
        {
            "sampleA": sample_ids[pos_a],
            "sampleB": sample_ids[pos_b],
            "stratumA": strata[pos_a],
            "stratumB": strata[pos_b],
            "timeA": times[pos_a],
            "timeB": times[pos_b],
            "sample_typeA": sample_types[pos_a],
            "sample_typeB": sample_types[pos_b],
            "diss": dmat.to_numpy()[pos_a, pos_b],
        }
    )

    # Keep each unordered pair once, and only pairs within a stratum.
    keep = data["sampleA"] < data["sampleB"]
    if stratum_var is not None:
        keep &= data["stratumA"] == data["stratumB"]
    data = data[keep]

    return data.assign(
        # A list comprehension (rather than apply(axis=1)) also works when no
        # pairs remain.
        pair_type=[
            pair_classifier(type_a, type_b)
            for type_a, type_b in zip(data["sample_typeA"], data["sample_typeB"])
        ],
        time_delta=np.abs(data["timeB"] - data["timeA"]),
        stratum=data["stratumA"],
    )

In [None]:


# For each species, fit a robust turnover model of between-sample strain-
# community dissimilarity and print the per-pair-type coefficients.
# Loop-local names (species_world, species_comm, turnover_fit) keep this cell
# from overwriting the global `sf_fit`, which later cells combine with
# `strain_order`/`strain_palette` built from the `motu_id` fit.
for species_id in ["100022", "102506", "102544", "100099"]:
    species_world = (
        sf.data.World.load(
            f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
        )
        .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
        .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    )

    # Bray-Curtis dissimilarity between the samples' strain compositions.
    species_comm = species_world.community.to_pandas()
    comm_bc_pdist = pd.DataFrame(
        squareform(pdist(species_comm, metric="braycurtis")),
        index=species_comm.index,
        columns=species_comm.index,
    )

    turnover_data = construct_turnover_analysis_data(
        comm_bc_pdist,
        meta=sample[lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])],
        sample_type_var="diet_or_media",
        stratum_var="subject_id",
        time_var="collection_date_relative_een_end",
    )

    # One coefficient per pair type (no global intercept), a spline in time
    # separation, and sum-coded subject effects.
    formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
    turnover_fit = smf.rlm(formula, data=turnover_data).fit()

    # All three pair-type levels: within-EEN, EEN vs PostEEN, within-PostEEN.
    print(
        species_id,
        turnover_fit.params["pair_type[EEN]"],
        turnover_fit.params["pair_type[EEN:PostEEN]"],
        turnover_fit.params["pair_type[PostEEN]"],
    )

In [None]:
# Human samples only: focal-species relative abundance (points) over strain
# composition (stacked bars), one row of panels per subject with data.
d0 = (
    sample.loc[
        lambda x: x.sample_type.isin(["human"]) & x.subject_id.isin(subject_order),
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
            "full_label",
        ],
    ]
    .sort_values(
        [
            "subject_id",
            "collection_date_relative_een_end",
            "sample_type",
            "diet_or_media",
            "mouse_genotype",
            "source_samples",
            "status_mouse_inflamed",
        ]
    )
    .assign(rabund=motu_rabund[motu_id])
    .dropna(subset=["rabund"])
)

# Panels only for sample types and subjects present in d0. An all-NaN count
# column for an absent sample type would give NaN width ratios and a NaN
# x-limit, which Axes.set_xlim rejects with a ValueError.
_present_types = set(d0.sample_type)
_plot_sample_types = [t for t in sample_type_order if t in _present_types]
_present_subjects = set(d0.subject_id)
_plot_subjects = [s for s in subject_order if s in _present_subjects]

_grid_sample_counts = (
    d0[["subject_id", "sample_type"]]
    .value_counts()
    .unstack(fill_value=0)
    .reindex(index=_plot_subjects, columns=_plot_sample_types, fill_value=0)
)
fig, axs = plt.subplots(
    len(_plot_subjects),
    len(_plot_sample_types),
    # 5 inches per subject row (15 for three rows).
    figsize=(20, 5 * len(_plot_subjects)),
    width_ratios=_grid_sample_counts.max().to_numpy(),
    sharey=True,
    squeeze=False,  # always a 2-D array of Axes, even for a single row/column
    gridspec_kw=dict(wspace=0.1, hspace=3),
)

# NOTE(review): uses the global `sf_fit`; its strain labels must match
# `strain_order`/`strain_palette` (built from the `motu_id` fit) -- confirm no
# earlier cell re-bound `sf_fit` to a different fit.
for subject_id, ax_row in zip(_plot_subjects, axs):
    d1 = d0[d0.subject_id == subject_id]
    strain_frac_sample_list = list(set(d1.index) & set(sf_fit.sample.values))
    if not strain_frac_sample_list:
        print(f"No strain analysis for {subject_id}.")
        comm = []
        _strain_order = []
    else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        # Genotype-similarity order of the retained strains, "-1" stacked last.
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()
    d2 = d1.join(comm)

    for sample_type, ax in zip(_plot_sample_types, ax_row):
        d3 = d2[d2.sample_type == sample_type].assign(
            xpos=lambda x: np.arange(len(x.index))
        )
        ax.scatter("xpos", "rabund", data=d3, color="k", s=10, label="__nolegend__")
        ax.set_ylim(-1e-3, 1)
        ax.set_yscale("symlog", linthresh=ylinthresh, linscale=0.1)
        ax.set_xlim(-0.5, _grid_sample_counts[sample_type].max())
        ax.set_xticks(d3.xpos)
        ax.set_xticklabels(d3.full_label)

        # Stacked strain fractions on a twin axis.
        ax1 = ax.twinx()
        top_last = 0
        for strain in _strain_order:
            ax1.bar(
                x="xpos",
                height=strain,
                data=d3,
                bottom=top_last,
                width=bar_width,
                alpha=1.0,
                color=strain_palette[strain],
                edgecolor="k",
                lw=1,
                label="__nolegend__",
            )
            top_last = top_last + d3[strain]
            # Empty proxy artist so the strain gets a legend entry on `ax`.
            ax.scatter([], [], color=strain_palette[strain], label=strain, marker="s", s=80)
        ax1.set_yticks([])
        # Draw the points above the bars: raise `ax` and hide its background.
        ax.set_zorder(ax1.get_zorder() + 1)
        ax.patch.set_visible(False)
        ax1.patch.set_visible(True)

        ax1.set_ylim(0, 1)
        lib.plot.rotate_xticklabels(ax=ax)
        # Legend on the right-most column (there is no "mouse" column here).
        if sample_type == _plot_sample_types[-1]:
            ax.legend(bbox_to_anchor=(1, 1))