# Preamble

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root (read from the PROJECT_ROOT environment variable,
# KeyError if unset) so the relative "meta/..." and "data/..." paths used in
# later cells resolve.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Last expression of the cell: echoes the resolved working directory.
_os.path.realpath(_os.path.curdir)

## Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform, cdist
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Return ``labels`` rearranged into the leaf order of a linkage tree.

    Parameters
    ----------
    linkage : ndarray
        SciPy linkage matrix (e.g. from ``scipy.cluster.hierarchy.linkage``).
        Note: inside this function the name shadows the module-level
        ``linkage`` import.
    labels : array-like supporting integer-array indexing
        One label per original observation, in the order used to build
        ``linkage``.
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    # With its default callback, pre_order() yields the leaf ids left to right.
    leaf_ids = root.pre_order()
    return labels[leaf_ids]

def label_een_experiment_sample(x):
    """Build a human-readable label for one row of the sample table.

    Intended for ``DataFrame.apply(..., axis=1)``. ``x`` is a row whose
    ``name`` is the sample ID. Every label starts with ``[<name>]`` and the
    remaining fields depend on ``x.sample_type``:

    - ``"human"``: collection date relative to EEN end, then diet/media.
    - ``"Fermenter_inoculum"``: source samples, ``inoc``, diet/media.
    - ``"Fermenter"``: source samples, diet/media.
    - ``"mouse"``: source samples, diet/media, then ``inflam`` or
      ``not_inf`` taken from ``x.status_mouse_inflamed``.

    Raises
    ------
    ValueError
        If ``sample_type`` is not one of the above, or a mouse row has a
        ``status_mouse_inflamed`` other than ``"Inflamed"``/``"not_Inflamed"``.
    """
    prefix = f"[{x.name}]"
    if x.sample_type == "human":
        return f"{prefix} {x.collection_date_relative_een_end} {x.diet_or_media}"
    if x.sample_type == "Fermenter_inoculum":
        return f"{prefix} {x.source_samples} inoc {x.diet_or_media}"
    if x.sample_type == "Fermenter":
        return f"{prefix} {x.source_samples} {x.diet_or_media}"
    if x.sample_type == "mouse":
        # Earlier variants of this label also showed a mouse marker and
        # x.mouse_genotype.
        if x.status_mouse_inflamed == "Inflamed":
            tag = "inflam"
        elif x.status_mouse_inflamed == "not_Inflamed":
            tag = "not_inf"
        else:
            # Report the field that was actually unrecognized.
            raise ValueError(
                f"mouse inflammation status {x.status_mouse_inflamed!r} not understood"
            )
        return f"{prefix} {x.source_samples} {x.diet_or_media} {tag}"
    raise ValueError(f"sample type {x.sample_type} not understood")

In [None]:
# Plotting defaults: seaborn's "talk" context and a low figure DPI (50).
sns.set_context(context="talk")
plt.rcParams.update({"figure.dpi": 50})

# Prepare Metadata

In [None]:
# Display order for subjects: note "H" sits between "B" and "C", and
# "I"/"J" do not appear.
subject_order = list("ABHCDEFGKLMNOPQRSTU")

In [None]:
# Sample metadata, indexed by sample_id. Two derived columns:
#   label      -- tuple (collection_date_relative_een_end, diet_or_media,
#                 sample_id), built while sample_id is still a column;
#   full_label -- display string from label_een_experiment_sample, which
#                 reads the sample ID from each row's ``name`` (the new index).
sample = (
    pd.read_table("meta/een-mgen/sample.tsv")
    .assign(
        label=lambda x: x[
            ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
        ].apply(tuple, axis=1)
    )
    .set_index("sample_id")
    .assign(full_label=lambda d: d.apply(label_een_experiment_sample, axis=1))
)

# Prepare Data

In [None]:
# zOTU count table: rows indexed by "#OTU ID" (renamed "zotu"); the columns are
# samples plus a "taxonomy" column.
rotu_counts = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")
rotu_taxonomy = rotu_counts.taxonomy
# Drop the taxonomy column and transpose to samples x zOTUs.
rotu_counts = rotu_counts.drop(columns=["taxonomy"]).T
# Per-sample relative abundance (each row sums to 1).
# NOTE(review): a sample with zero total counts would become all-NaN here,
# which the linkage below would reject -- confirm none exist.
rotu_rabund = rotu_counts.divide(rotu_counts.sum(1), axis=0)

# Average-linkage (UPGMA) clustering of samples on Bray-Curtis dissimilarity,
# with leaves optimally ordered.
sample_rotu_bc_linkage = sp.cluster.hierarchy.linkage(
    rotu_rabund, method="average", metric="braycurtis", optimal_ordering=True
)

In [None]:
# Species-level depth table: read as long (sample, species_id, depth) records
# and pivoted to a samples x species matrix, with absent pairs filled as 0.
motu_depth = (pd.read_table(
    "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
    names=['sample', "species_id", 'depth'], index_col=['sample', "species_id"],
    )
    .depth.unstack(fill_value=0)
    # Species IDs become strings; sample names are normalized to "CF_<int>"
    # (int() drops any zero padding in the numeric suffix).
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    .rename({'CF_15': 'CF_11', 'CF_11': 'CF_15'})  # Sample swap: exchange the CF_11 / CF_15 labels
)
# Per-sample relative abundance of each species (rows sum to 1).
motu_rabund = motu_depth.divide(motu_depth.sum(1), axis=0)

motu_rabund

# Species Enrichment Analysis

In [None]:
def enrichment_test(d):
    """Summarize the EEN -> PostEEN change of one species across subjects.

    ``d`` holds paired abundances in columns ``"EEN"`` and ``"PostEEN"``
    (one row per pair). Returns a Series with:

    - ``log2_ratio``: mean of the per-row log2(PostEEN / EEN);
    - ``mean_EEN`` / ``mean_PostEEN``: column means;
    - ``pvalue``: paired Wilcoxon signed-rank p-value, or NaN if SciPy
      raises ValueError for this input.
    """
    een = d["EEN"]
    post_een = d["PostEEN"]
    try:
        pvalue = sp.stats.wilcoxon(een, post_een)[1]
    except ValueError:
        pvalue = np.nan
    summary = {
        "log2_ratio": np.log2(post_een / een).mean(),
        "mean_EEN": een.mean(),
        "mean_PostEEN": post_een.mean(),
        "pvalue": pvalue,
    }
    return pd.Series(summary)

In [None]:
# Per-species EEN vs PostEEN enrichment across subjects:
#   1. add a per-species pseudocount equal to that species' smallest non-zero
#      relative abundance (zeros are swapped for inf before taking the min);
#   2. attach subject_id and diet_or_media from `sample`;
#   3. average all samples within each (subject_id, diet_or_media) group;
#   4. reshape to one row per (subject_id, species) with EEN and PostEEN
#      columns, dropping rows missing either value (dropna);
#   5. apply enrichment_test per species. It reads only EEN/PostEEN and
#      recomputes log2_ratio itself, so the log2_ratio column assigned here
#      is not what it uses.
# NOTE(review): a species that is zero in every sample gets an inf pseudocount,
# so its ratios become NaN -- confirm such species are expected/filtered.
motu_enrichment_results = (
    motu_rabund.apply(lambda x: x + x.replace({0: np.inf}).min())
    .join(sample[["subject_id", "diet_or_media"]])
    .groupby(["subject_id", "diet_or_media"])
    .mean()
    .stack()
    .unstack("diet_or_media")[["EEN", "PostEEN"]]
    .dropna()
    .assign(log2_ratio=lambda x: np.log2(x["PostEEN"] / x["EEN"]))
    .rename_axis(index=["subject_id", "motu_id"])
    .groupby(level="motu_id")
    .apply(enrichment_test)
)
# Show the 20 species with the smallest (unadjusted) Wilcoxon p-values.
motu_enrichment_results.sort_values("pvalue", ascending=True).head(20)
# Scratch code for inspecting one species' per-subject log2 ratios:
# fig, ax = plt.subplots()
# print(d.log2_ratio.mean())
# print(sp.stats.wilcoxon(d['PostEEN'], d['EEN']))
# ax.hist(d.log2_ratio, bins=20)

# Strain Time-series

In [None]:
def pair_classifier(sample_typeA, sample_typeB):
    """Return an order-independent label for a pair of sample types.

    Identical types collapse to that single type (e.g. ``"EEN"``); differing
    types are sorted and joined with ``":"`` (e.g. ``"EEN:PostEEN"``).
    """
    return ":".join(sorted({sample_typeA, sample_typeB}))


def construct_turnover_analysis_data(
    dmat,
    meta,
    sample_type_var,
    stratum_var=None,
    time_var=None,
):
    """Flatten a square dissimilarity matrix into a table of sample pairs.

    Parameters
    ----------
    dmat : pd.DataFrame
        Square dissimilarity matrix whose rows and columns are labelled by
        sample ID.
    meta : pd.DataFrame
        Sample metadata indexed by sample ID. Samples absent from ``dmat``'s
        index, or with a missing value in any requested column, are dropped.
    sample_type_var : str
        Column of ``meta`` with each sample's type (used for ``pair_type``).
    stratum_var : str, optional
        Column grouping samples (e.g. subject). When given, only pairs within
        the same stratum are kept. When omitted, all pairs are kept and the
        stratum columns are NaN.
    time_var : str, optional
        Numeric column used for ``time_delta``; NaN columns when omitted.

    Returns
    -------
    pd.DataFrame
        One row per unordered pair (``sampleA < sampleB``, no self-pairs),
        with columns ``sampleA, sampleB, stratumA, stratumB, timeA, timeB,
        sample_typeA, sample_typeB, diss, pair_type, time_delta, stratum``.
        An empty (but fully-columned) frame if no samples overlap.
    """
    requested = (sample_type_var, stratum_var, time_var)
    # Drop unused (None) variables and duplicates, keeping order.
    var_list = list(dict.fromkeys(v for v in requested if v is not None))
    meta = meta.reindex(dmat.index)[var_list].dropna()
    samples = meta.index.to_numpy()
    n = len(samples)
    diss = dmat.loc[meta.index, meta.index].to_numpy()

    # Every ordered (A, B) index pair, A-major: the same row order as
    # itertools.product(samples, repeat=2), without per-cell .loc lookups.
    idx_a = np.repeat(np.arange(n), n)
    idx_b = np.tile(np.arange(n), n)

    def _values(var, idx):
        # Per-pair values of one metadata column; NaN when the column is unused.
        if var is None:
            return np.full(len(idx), np.nan)
        return meta[var].to_numpy()[idx]

    data = pd.DataFrame(
        {
            "sampleA": samples[idx_a],
            "sampleB": samples[idx_b],
            "stratumA": _values(stratum_var, idx_a),
            "stratumB": _values(stratum_var, idx_b),
            "timeA": _values(time_var, idx_a),
            "timeB": _values(time_var, idx_b),
            "sample_typeA": _values(sample_type_var, idx_a),
            "sample_typeB": _values(sample_type_var, idx_b),
            "diss": diss[idx_a, idx_b],
        }
    )
    data["pair_type"] = [
        pair_classifier(a, b) for a, b in zip(data.sample_typeA, data.sample_typeB)
    ]
    data["time_delta"] = np.abs(data.timeB - data.timeA)

    # sampleA < sampleB keeps one copy of each unordered pair and drops
    # self-pairs.
    keep = data.sampleA < data.sampleB
    if stratum_var is not None:
        keep &= data.stratumA == data.stratumB
    return data[keep].assign(stratum=lambda x: x.stratumA)

In [None]:
# Worked example for one species: strain-level Bray-Curtis turnover between
# EEN and PostEEN samples within each subject, modelled with OLS.
species_id = '101378'
inpath = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
sf_fit = (
    sf.data.World.load(
        inpath
    )
    # Normalize sample names to "CF_<int>".
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    # Swap CF_11 / CF_15 labels, the same sample-swap correction as motu_depth.
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    # .drop_low_abundance_strains(0.01)
    # .rename_coords(strain=str)
)

# Strain composition per sample; rows (samples) label the distance matrix.
comm = sf_fit.community.to_pandas()
comm_bc_pdist = pd.DataFrame(squareform(pdist(comm, metric='braycurtis')), index=comm.index, columns=comm.index)

# Within-subject sample pairs, restricted to EEN / PostEEN samples.
turnover_data = construct_turnover_analysis_data(
    comm_bc_pdist,
    meta=sample[lambda x: x.diet_or_media.isin(['EEN', 'PostEEN'])],
    sample_type_var="diet_or_media",
    stratum_var="subject_id",
    time_var="collection_date_relative_een_end",
)

# No intercept: one coefficient per pair_type. Plus a natural cubic regression
# spline (patsy cr, 4 df) on the time gap, and subject effects with sum-to-zero
# contrasts.
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
fit = smf.ols(formula, data=turnover_data).fit()

# Pair-type coefficients (NaN if a pair type is absent). Not used further in
# this cell; the per-species loop below collects the same values.
coef_list = []
for coef in ['pair_type[EEN]', 'pair_type[PostEEN]', 'pair_type[EEN:PostEEN]']:
    if coef in fit.params:
        coef_list.append(fit.params[coef])
    else:
        coef_list.append(np.nan)

# Dissimilarity vs. time gap, one series per pair type (log-scaled x axis).
for pair_type, d1 in turnover_data.groupby('pair_type'):
    plt.scatter('time_delta', 'diss', data=d1, label=pair_type)
plt.xscale('log')
plt.legend()
fit.summary()


In [None]:
# Fit the strain-turnover model for every species in the enrichment table and
# collect its pair-type coefficients plus the number of pairs used.
results = []
# Loop invariants: the EEN/PostEEN metadata subset, the model formula and the
# coefficient names to extract.
een_meta = sample[lambda x: x.diet_or_media.isin(['EEN', 'PostEEN'])]
formula = "diss ~ 0 + pair_type + cr(time_delta, 4) + C(stratum, Sum)"
coef_names = ['pair_type[EEN]', 'pair_type[PostEEN]', 'pair_type[EEN:PostEEN]']

for species_id in motu_enrichment_results.index:
    inpath = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
    if not os.path.exists(inpath):
        print(f"No sfacts fit for {species_id}.")
        continue
    sf_fit = (
        sf.data.World.load(inpath)
        # Normalize sample names to "CF_<int>".
        .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
        # Swap CF_11 / CF_15 labels (same sample-swap correction as motu_depth).
        .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
    )

    comm = sf_fit.community.to_pandas()
    comm_bc_pdist = pd.DataFrame(
        squareform(pdist(comm, metric='braycurtis')),
        index=comm.index,
        columns=comm.index,
    )

    turnover_data = construct_turnover_analysis_data(
        comm_bc_pdist,
        meta=een_meta,
        sample_type_var="diet_or_media",
        stratum_var="subject_id",
        time_var="collection_date_relative_een_end",
    )
    if turnover_data.empty:
        print(f"No data for {species_id}.")
        continue

    try:
        fit = smf.ols(formula, data=turnover_data).fit()
    except (ZeroDivisionError, ValueError):
        # Presumably a degenerate design (too few pairs/levels); skip species.
        print(f"OLS failed for {species_id}.")
        continue

    # Pair types absent from this species' data have no coefficient -> NaN.
    coefs = fit.params.reindex(coef_names).tolist()
    results.append((species_id, *coefs, turnover_data.shape[0]))

results = pd.DataFrame(results, columns=['motu_id', 'EEN', 'PostEEN', 'Transition', 'num_pairs'])

In [None]:
# Join per-species enrichment statistics with the turnover coefficients.
#   relative_transition: EEN:PostEEN coefficient minus EEN coefficient, i.e.
#     the model's extra dissimilarity for cross-diet pairs vs. EEN-EEN pairs.
#   signif_enrich: unadjusted Wilcoxon p < 0.05.
# Species without a turnover fit (NaN relative_transition) are dropped.
# NOTE(review): `d` is reassigned to a sample subset in a later cell.
d = motu_enrichment_results.join(results.set_index('motu_id')).assign(
    relative_transition=lambda x: x['Transition'] - x['EEN'],
    signif_enrich=lambda w: w.pvalue < 0.05,
).dropna(subset=['relative_transition'])

# Mean EEN abundance (log x) vs. turnover excess (symlog y), coloured by
# enrichment significance.
plt.scatter('mean_EEN', 'relative_transition', data=d, c='signif_enrich')
plt.yscale('symlog', linthresh=1e-2)
plt.xscale('log')

In [None]:
# Species with mean EEN abundance > 1e-3 and relative_transition > 0.1.
d[lambda x: (x.mean_EEN > 1e-3) & (x.relative_transition > 1e-1)]

In [None]:
# Top 20 species by relative_transition among those with mean EEN abundance > 1e-3.
d[lambda x: (x.mean_EEN > 1e-3)].sort_values('relative_transition', ascending=False).head(20)

In [None]:
# Rank correlation between abundance change (log2 PostEEN/EEN) and turnover excess.
sp.stats.spearmanr(d['log2_ratio'], d['relative_transition'])

In [None]:
# Top 50 species by relative_transition (no abundance filter).
d.sort_values('relative_transition', ascending=False).head(50)

In [None]:
# NOTE(review): `inpath` is not set in this cell. It holds whatever the
# per-species loop above assigned last, i.e. the path for the final species in
# motu_enrichment_results.index. That file may not exist, because the loop
# assigns inpath before its os.path.exists check. Consider setting
# species_id/inpath explicitly here.
sf_fit = (
    sf.data.World.load(
        inpath
    )
    # Normalize sample names to "CF_<int>", then swap CF_11 / CF_15 labels.
    .rename_coords(sample=lambda s: "CF_{}".format(int(s.split("_")[1])))
    .rename_coords(sample={"CF_11": "CF_15", "CF_15": "CF_11"})
)

# Metadata rows for a single subject.
# NOTE(review): this overwrites `d` (the per-species summary table) from above.
subject_id = 'A'
d = sample[lambda x: x.subject_id == subject_id]


else:
        w = (
            sf_fit.sel(sample=strain_frac_sample_list)
            .drop_low_abundance_strains(drop_strains_thresh)
            .rename_coords(strain=str)
        )
        _strain_order = [s for s in strain_order if s in w.strain] + ["-1"]
        comm = w.community.to_pandas()
    d2 = d1.join(comm)
    for sample_type, ax in zip(sample_type_order, ax_row):
        d3 = d2[lambda x: (x.sample_type == sample_type)].assign(xpos=lambda x: np.arange(len(x.index)))
        ax.scatter("xpos", "rabund", data=d3, color='k', s=10, label='__nolegend__')
        # ax.set_aspect(700, adjustable="datalim", anchor="NW")
        ax.set_ylim(-1e-3, 1)
        ax.set_yscale("symlog", linthresh=1e-4, linscale=0.1)
        ax.set_xlim(-0.5, _grid_sample_counts[sample_type].max())
        ax.set_xticks(d3.xpos)
        ax.set_xticklabels(d3.full_label)
        
        # Plot stacked barplot
        ax1 = ax.twinx()
        top_last = 0
        for strain in _strain_order:
            ax1.bar(
                x='xpos',
                height=strain,
                data=d3,
                bottom=top_last,
                width=bar_width,
                alpha=1.0,
                color=strain_palette[strain],
                edgecolor="k",
                lw=1,
                label='__nolegend__',
            )
            top_last += d3[strain]
            ax.scatter([], [], color=strain_palette[strain], label=strain, marker='s', s=80)
        ax1.set_yticks([])
        # Put strains behind points:
        ax.set_zorder(ax1.get_zorder() + 1)  # put ax in front of ax1
        ax.patch.set_visible(False)  # hide the 'canvas'
        ax1.patch.set_visible(True)  # show the 'canvas'

        ax1.set_ylim(0, 1)
        lib.plot.rotate_xticklabels(ax=ax)
        if sample_type == 'mouse':
            ax.legend(bbox_to_anchor=(1, 1))