# Preamble

In [None]:
%load_ext autoreload

In [None]:
# Switch the working directory to the project root named by the PROJECT_ROOT
# environment variable (KeyError if it is unset). The cells below read files
# through relative paths such as "ref/", "meta/" and "data/".
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
# Final expression: show the resolved working directory.
_os.path.realpath(_os.path.curdir)

## Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.dissimilarity
import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
# Plot defaults for the whole notebook: seaborn "talk" context, 100 dpi figures.
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 100

In [None]:
def _calculate_2tailed_pvalue_from_perm(obs, perms):
    hypoth_left = perms > obs
    hypoth_right = perms < obs
    null_p_left = (hypoth_left.sum() + 1) / (len(hypoth_left) + 1)
    null_p_right = (hypoth_right.sum() + 1) / (len(hypoth_right) + 1)
    return np.minimum(null_p_left, null_p_right) * 2

In [None]:
def linkage_order(linkage, labels):
    """Return ``labels`` reordered to follow the leaves of a linkage tree.

    ``linkage`` is a SciPy linkage matrix. It is converted to a tree, the
    leaf ids are collected by a pre-order traversal, and those ids are used
    to index ``labels`` (which must therefore accept a list of integer
    positions).
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]


def is_prime(n):
    """Return True if ``n`` is prime, False otherwise.

    Values of 1 or less are never prime. Otherwise ``n`` is tested by trial
    division against every integer from 2 up to the truncated square root.
    """
    if n <= 1:
        return False
    largest_candidate = int(n**0.5)
    return all(n % divisor != 0 for divisor in range(2, largest_candidate + 1))


def iterate_primes_up_to(n, return_index=False):
    """Yield the primes strictly below ``ceil(n)`` in increasing order.

    With ``return_index=True`` each item is an ``(index, prime)`` pair, where
    ``index`` counts the primes yielded so far, starting at 0.
    """
    limit = int(np.ceil(n))
    primes = (candidate for candidate in range(limit) if is_prime(candidate))
    if return_index:
        yield from enumerate(primes)
    else:
        yield from primes


def maximally_shuffled_order(sorted_order):
    """Return the items of ``sorted_order`` in a deterministic shuffled order.

    Each item is tagged with its original position modulo every prime below
    ``ceil(sqrt(n))``. The table is sorted by those residues, and the original
    positions read off in that sorted order become the new ranks of the items
    (assigned positionally, in the original order). The labels are returned
    sorted by new rank. Presumably the goal is to separate items that were
    adjacent in the input.

    Parameters
    ----------
    sorted_order : sequence
        Labels in their original order; used as the index of a DataFrame.

    Returns
    -------
    list
        The same labels, reordered. An empty input yields an empty list.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    for prime in primes_list:
        table[prime] = table.original_order % prime
    # The k-th item (in original order) receives as its rank the original
    # position of the k-th row of the residue-sorted table.
    new_order = table.sort_values(primes_list).original_order.values
    table = table.assign(new_order=new_order)
    return table.sort_values("new_order").index.to_list()

# Construct Metadata

In [None]:
# Per-COG annotation table, indexed by COG ID. Column names are supplied
# here through `names` rather than taken from the file.
cog_meta = pd.read_table(
    "ref/cog-20.meta.tsv",
    encoding="latin10",
    names=["cog", "categories", "description", "gene_name", "pathway", "_5", "color"],
    index_col="cog",
)
# Final expression: display the table.
cog_meta

In [None]:
# Per-category annotation table (color and description), indexed by the
# category code. Column names are supplied here through `names`.
cog_category_meta = pd.read_table(
    "ref/cog-20.categories.tsv",
    names=["category", "color", "description"],
    index_col="category",
)
# Final expression: display the table.
cog_category_meta

In [None]:
# Colors keyed by sample-pair type.
pair_type_palette = {
    "EEN": "teal",
    "PostEEN": "mediumblue",
    "Transition": "blueviolet",
}

# Colors keyed by diet / growth medium.
diet_palette = {
    "EEN": "lightgreen",
    "PostEEN": "lightblue",
    "InVitro": "plum",
    "PreEEN": "lightpink",
}

# Fixed display order of subject IDs (note that "H" is placed after "B").
subject_order = [
    "A",
    "B",
    "H",
    "C",
    "D",
    "E",
    "F",
    "G",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
]

# NOTE: Requires a dummy value because I want exactly 20 items.
# The subject list is padded with "dummy{i}" names up to 20 entries before
# being passed to the palette constructor with the "tab20" colormap.
subject_palette = lib.plot.construct_ordered_palette(
    subject_order + [f"dummy{i}" for i in range(20 - len(subject_order))], cm="tab20"
)
# Extra entry "X", drawn in black.
subject_palette["X"] = "black"
# Display order, marker and line style for each pair type.
pair_type_order = ["EEN", "Transition", "PostEEN"]
pair_type_marker_palette = {"EEN": "s", "Transition": ">", "PostEEN": "o"}
pair_type_linestyle_palette = {"EEN": ":", "Transition": "-.", "PostEEN": "-"}

In [None]:
def _label_experiment_sample(x):
    if x.sample_type == "human":
        label = f"{x.subject_id} [{x.sample_id}] {x.collection_date_relative_een_end} {x.diet_or_media}"
    elif x.sample_type in ["Fermenter_inoculum"]:
        label = (
            f"{x.subject_id} [{x.sample_id}] {x.source_samples} inoc {x.diet_or_media}"
        )
    elif x.sample_type in ["Fermenter"]:
        label = (
            f"{x.subject_id} [{x.sample_id}] {x.source_samples} frmnt {x.diet_or_media}"
        )
    elif x.sample_type in ["mouse"]:
        if x.status_mouse_inflamed == "Inflamed":
            label = f"{x.subject_id} [{x.sample_id}] {x.source_samples} 🐭 {x.mouse_genotype} {x.diet_or_media} inflam"
        elif x.status_mouse_inflamed == "not_Inflamed":
            label = f"{x.subject_id} [{x.sample_id}] {x.source_samples} 🐭 {x.mouse_genotype} {x.diet_or_media} not_inf"
        else:
            raise ValueError(f"sample type {x.status_mouse_inflamed} not understood")
    else:
        raise ValueError(f"sample type {x.sample_type} not understood")
    return label

In [None]:
# Sample metadata, indexed by sample_id, with two derived label columns:
#   label        - tuple of (subject, relative collection date, diet/medium,
#                  sample_id) taken from each row;
#   fuller_label - string built per row by _label_experiment_sample.
# Both are computed before sample_id is moved into the index.
sample = (
    pd.read_table("meta/een-mgen/sample.tsv")
    .assign(
        label=lambda x: x[
            [
                "subject_id",
                "collection_date_relative_een_end",
                "diet_or_media",
                "sample_id",
            ]
        ].apply(tuple, axis=1),
        fuller_label=lambda d: d.apply(_label_experiment_sample, axis=1),
    )
    .set_index("sample_id")
)
# Subject metadata, indexed by subject_id.
subject = pd.read_table("meta/een-mgen/subject.tsv", index_col="subject_id")

In [None]:
# Ubiquitous, single-copy genes to be used for estimating total genome depth:

schg_cog_list = [
    "COG0012",
    "COG0016",
    "COG0048",
    "COG0049",
    "COG0052",
    "COG0080",
    "COG0081",
    "COG0085",
    "COG0087",
    "COG0088",
    "COG0090",
    "COG0091",
    "COG0092",
    "COG0093",
    "COG0094",
    "COG0096",
    "COG0097",
    "COG0098",
    "COG0099",
    "COG0100",
    "COG0102",
    "COG0103",
    "COG0124",
    "COG0184",
    "COG0185",
    "COG0186",
    "COG0197",
    "COG0200",
    "COG0201",
    "COG0256",
    "COG0495",
    "COG0522",
    "COG0525",
    "COG0533",
    "COG0542",  # This one is a depth outlier...
]

# Final expression: display the metadata rows for these COGs (raises KeyError
# if any ID is missing from cog_meta).
cog_meta.loc[schg_cog_list]

# Load Data

## Species

In [None]:
# List of all species with pangenome profiles

# Keep the species_id values of rows whose species_group_id is "een", check
# that no ID is repeated, and convert to a plain list.
# NOTE(review): the species_id dtype is left to pandas' inference here, while
# a later validation cell selects species by string IDs (e.g. "100003") —
# confirm the two agree.
species_list = pd.read_table("meta/species_group.tsv")[
    lambda x: x.species_group_id == "een"
].species_id
assert species_list.is_unique
species_list = list(species_list)

## Genes

In [None]:
# Load table of gene depths for each species and aggregate by COG.
# Can take up to 8 minutes to compile everything
# Filled below with one (cog, sample)-indexed Series per species, then
# replaced by a single 3-D array.
cog_depth = {}

for species in tqdm(species_list):
    # Gene-to-COG assignments for this species' pangenome centroids.
    gene_x_cog_inpath = (
        f"data/species/sp-{species}/pangenome.centroids.emapper.gene_x_cog.tsv"
    )
    # f"data/species/sp-{species}/midasdb_v15.emapper.gene75_x_cog.tsv"  # This is new as of 2023-12-06 and used the "voting" procedure. Also, it's based on MIDASDB v1.5... so that might be problematic.
    # Per-gene depth array for this species.
    gene_depth_inpath = (
        f"data/group/een/species/sp-{species}/r.proc.gene99_new-v22-agg75.depth2.nc"
    )
    # De-duplicated rows, indexed by gene_id and squeezed; the join below
    # relies on the result carrying a "cog" column/name.
    _gene_x_cog = (
        pd.read_table(gene_x_cog_inpath)
        .drop_duplicates()
        .set_index("gene_id")
        .squeeze()
    )

    # Calculate the depth of each COG by summing all genes labeled as that COG.
    # The depth table is transposed before joining on the gene index.
    _cog_depth = (
        xr.load_dataarray(gene_depth_inpath)
        .to_pandas()
        .T.join(_gene_x_cog)
        .groupby("cog")
        .sum()
    )
    # Long format: one value per (cog, sample) pair.
    cog_depth[species] = _cog_depth.stack()

# Combine all species into one array with dimensions (cog, sample, species);
# combinations absent for a species are filled with 0.
cog_depth = (
    pd.DataFrame(cog_depth)
    .stack()
    .rename_axis(["cog", "sample", "species"])
    .to_xarray()
    .fillna(0)
)

# Normalize sample names and swap the mislabeled samples.
# Names are rebuilt as "CF_<int>" from the second "_"-separated field (which
# drops any zero padding), then CF_11 and CF_15 are exchanged.
cog_depth["sample"] = (
    cog_depth.sample.to_series()
    .map(lambda x: "CF_" + str(int(x.split("_")[1])))
    .replace({"CF_15": "CF_11", "CF_11": "CF_15"})
    .values
)

### Validate SCHGs

In [None]:
# Depth of each single-copy housekeeping COG, summed over species and laid out
# as a 2-D pandas table.
schg_cog_by_sample_depth = cog_depth.sum("species").sel(cog=schg_cog_list).to_pandas()

# Clustered heatmap using cosine distance; every row label is drawn.
sns.clustermap(
    schg_cog_by_sample_depth,
    metric="cosine",
    yticklabels=1,
)

In [None]:
# Pairwise correlation between the rows of the SCHG depth table, obtained as
# 1 minus the "correlation" dissimilarity returned by the project helper.
pairwise_schg_cog_correlation = 1 - lib.dissimilarity.dmatrix(
    schg_cog_by_sample_depth, metric="correlation"
)
# Clustered heatmap with the color scale fixed to [0, 1] (power-law exponent
# 1, i.e. a linear mapping).
sns.clustermap(
    pairwise_schg_cog_correlation, norm=mpl.colors.PowerNorm(1, vmin=0, vmax=1)
)

## Species Depth

In [None]:
# Species depth per sample: mean depth across the single-copy housekeeping COGs.
species_depth = cog_depth.sel(cog=schg_cog_list).mean("cog")
# Relative abundance: each species' depth divided by the total over species.
species_relabund = species_depth / species_depth.sum("species")

### Validate Species Depth Estimates

In [None]:
# Spot-check: for a few species/sample pairs, histogram the per-COG depths and
# mark the estimated species depth with a vertical black line.
_species_list = ["100003", "102506", "100022"]
_sample_list = ["CF_94", "CF_93"]
_grid_shape = (len(_species_list), len(_sample_list))

# One row of panels per species, one column per sample, with shared axes.
fig, axs = plt.subplots(
    *_grid_shape,
    figsize=(5 * _grid_shape[1], 3 * _grid_shape[0]),
    sharex=True,
    sharey=True,
)
# Force a 2-D grid of axes even when a dimension has length 1.
axs = np.asanyarray(axs).reshape(_grid_shape)

# Log-spaced bins from 1e-3 to 1e4.
bins = np.logspace(-3, 4, num=100)

for _row, _species in enumerate(_species_list):
    for _col, _sample in enumerate(_sample_list):
        ax = axs[_row, _col]
        _depths = cog_depth.sel(sample=_sample, species=_species).to_pandas()
        ax.hist(_depths, bins=bins)
        ax.axvline(species_depth.loc[_sample, _species], color="black")
        ax.set_title((_species, _sample))

# Applied to the last panel drawn; the x-axis is shared across panels.
ax.set_xscale("log")

## Detection Limit Imputation

In [None]:
# Detection limit of a COG: its smallest non-zero depth over all samples and
# species. Zeros are masked with +inf so they cannot win the minimum.
_nonzero_depth = cog_depth.where(cog_depth != 0, np.inf)
cog_detection_limit = _nonzero_depth.min(("sample", "species"))

# COGs whose limit is still +inf were never observed with non-zero depth.
_never_detected = (cog_detection_limit == np.inf).to_series()
undetected_cogs_list = idxwhere(_never_detected)

# Replace every zero depth with that COG's detection limit, then drop the
# COGs that were never detected.
_imputed_depth = cog_depth.where(cog_depth != 0, cog_detection_limit)
cog_depth_or_detection_limit = _imputed_depth.drop_sel(cog=undetected_cogs_list)
cog_depth_or_detection_limit

## Normalize COG depth

In [None]:
# Total genome depth per sample: for each species take the median depth of
# the single-copy housekeeping COGs (computed on the imputed table), then sum
# those medians over species.
total_genome_depth = (
    cog_depth_or_detection_limit.sel(cog=schg_cog_list)
    .median("cog")
    .sum("species")  # NOTE: Mean or Median? Does it matter?
)
# COG depth summed over species, expressed relative to total genome depth.
normalized_cog_depth_by_sample = (
    cog_depth_or_detection_limit.sum("species") / total_genome_depth
)

## Aggregate by Subject

In [None]:
# Collapse samples to one value per (subject, diet) pair:
#   1. transpose the normalized depths and attach each sample's subject and
#      diet from the sample metadata table;
#   2. keep only EEN and PostEEN samples;
#   3. take the median within each (subject, diet) group;
#   4. unstack the diet level, drop rows containing missing values, and stack
#      the diet level back, so only complete rows survive.
normalized_cog_depth_by_subject_and_type = (
    normalized_cog_depth_by_sample.to_pandas()
    .T.join(sample[["subject_id", "diet_or_media"]])[
        lambda x: x.diet_or_media.isin(["EEN", "PostEEN"])
    ]
    .groupby(["subject_id", "diet_or_media"])
    .median()  # NOTE: Mean or Median?
    .unstack("diet_or_media")
    .dropna()
    .stack("diet_or_media")
)
# Final expression: display the table.
normalized_cog_depth_by_subject_and_type

## Perform Tests

In [None]:
def _summarize_paired_cog(paired):
    """Return (mean EEN, mean PostEEN, mean log2 ratio, median log2 ratio, p-value).

    ``paired`` has one row per subject and the columns ``EEN`` and
    ``PostEEN``. The ratio is PostEEN over EEN. The p-value comes from a
    Wilcoxon signed-rank test on the two columns and is NaN if the test
    raises ``ValueError``.
    """
    een, post = paired.EEN, paired.PostEEN
    log2_ratio = np.log2(post / een)
    try:
        pval = sp.stats.wilcoxon(een, post).pvalue
    except ValueError:
        pval = np.nan
    return (een.mean(), post.mean(), log2_ratio.mean(), log2_ratio.median(), pval)


# One paired test per COG; the result has one row per COG and one column per
# summary statistic.
_stat_names = ("mean_een", "mean_post", "mean_log2_ratio", "median_log2_ratio", "pval")
pairwise_test_results = pd.DataFrame(
    {
        cog: _summarize_paired_cog(
            normalized_cog_depth_by_subject_and_type[cog].unstack()
        )
        for cog in tqdm(normalized_cog_depth_by_subject_and_type.columns)
    },
    index=_stat_names,
).T

## Calculate FDR

In [None]:
# Here is where filters on COGs would be defined. No abundance filter is
# currently applied; an earlier variant required a mean depth > 0.01 in at
# least one of the two periods:
#     [lambda x: (x.mean_een > 0.01) | (x.mean_post > 0.01)]
# and an optional effect-size requirement on hits was:
#     & (np.abs(x.mean_log2_ratio) > 0.2)


def _fdr_ignoring_missing(pvals):
    """Return FDR-corrected p-values, computed over the non-missing entries.

    COGs whose test could not be run carry a NaN p-value. Passing NaN to
    ``fdrcorrection`` turns every corrected value into NaN, so missing
    entries are excluded from the correction and keep a NaN FDR.
    """
    fdr = pd.Series(np.nan, index=pvals.index)
    tested = pvals.notna()
    if tested.any():
        fdr[tested] = fdrcorrection(pvals[tested])[1]
    return fdr


# A COG is a hit when its FDR is below 0.1; rows are sorted by FDR.
pairwise_test_results_filt_with_fdr = (
    pairwise_test_results.assign(
        fdr=lambda x: _fdr_ignoring_missing(x.pval),
        hit=lambda x: x.fdr < 0.1,
    ).sort_values("fdr")
)
pairwise_test_results_filt_with_fdr[lambda x: x.hit]

In [None]:
# Show the 20 hits with the largest median log2 ratio, annotated with the COG
# metadata.
pairwise_test_results_filt_with_fdr[lambda x: x.hit].join(cog_meta).sort_values(
    "median_log2_ratio", ascending=False
).head(20)

## Visualize

In [None]:
# Volcano-style scatter of all tested COGs: mean log2 ratio on x, p-value on a
# log-scaled, inverted y-axis. Hits are red, all other COGs grey.
d = pairwise_test_results_filt_with_fdr.sort_values("pval").join(cog_meta)

fig, ax = plt.subplots()
ax.scatter("mean_log2_ratio", "pval", data=d[d.hit], color="r", s=5)
ax.scatter("mean_log2_ratio", "pval", data=d[~d.hit], color="grey", s=5)
ax.invert_yaxis()
ax.set_yscale("log")
# Dashed guide lines at |mean log2 ratio| = 0.2 and at p = 0.05. These are
# visual references only; hits are defined by the FDR threshold.
ax.axvline(0.2, color="black", lw=1, linestyle="--")
ax.axvline(-0.2, color="black", lw=1, linestyle="--")
ax.axhline(0.05, color="black", lw=1, linestyle="--")
# NOTE(review): the plotted x values are `mean_log2_ratio`; confirm that "OR"
# is the intended wording for this axis label.
ax.set_xlabel("Mean Log2(OR)")
ax.set_ylabel("P-value")

In [None]:
# Visualize the time series of one top COG, with one line per subject.
cog = "COG5599"

# Sample metadata plus this COG's normalized depth, ordered by collection
# date; samples without a depth value are dropped.
d = (
    sample.assign(cog=normalized_cog_depth_by_sample.sel(cog=cog).to_series())
    .sort_values("collection_date_relative_een_end")
    .dropna(subset=["cog"])
)

# The loop variable is deliberately not named `subject`: that name holds the
# subject metadata table, and reusing it would replace the table with a
# subject ID string for every later cell.
for _subject_id in subject_order:
    plt.plot(
        "collection_date_relative_een_end",
        "cog",
        data=d[lambda w: w.subject_id == _subject_id],
    )
# Symmetric log x-axis, since dates relative to the end of EEN span negative
# and positive values.
plt.xscale("symlog")
plt.xlabel("Days before/after EEN End")
plt.ylabel("Normalized COG abundance")

In [None]:
# Number of COGs with an FDR below 0.1.
(pairwise_test_results_filt_with_fdr.fdr < 0.1).sum()

## COG-categories Enriched in Gene Hits

In [None]:
# Boolean membership matrix with one row per COG and one column per category.
# Each `categories` value is split into its elements (presumably a string of
# one-letter category codes — verify against the metadata file), giving one
# (cog, category) row per element; absent pairs are filled with False.
cog_x_cog_category_matrix = (
    cog_meta.categories.map(tuple)
    .explode()
    .rename("category")
    .reset_index()
    .assign(in_category=True)
    .set_index(["cog", "category"])
    .in_category.unstack(fill_value=False)
)

In [None]:
def test_enrichment(x, y, data):
    contingency = (
        data[[x, y]]
        .value_counts()
        .reindex(
            [(True, True), (True, False), (False, True), (False, False)], fill_value=0
        )
        .unstack()
    )
    contingency_pc = contingency + 1
    log2_odds_ratio_pc = np.log2(
        (contingency_pc.loc[True, True] / contingency_pc.loc[True, False])
        / (contingency_pc.loc[False, True] / contingency_pc.loc[False, False])
    )
    num_hit = contingency_pc.loc[True, True] - 1
    return (
        num_hit,
        log2_odds_ratio_pc,
        *sp.stats.fisher_exact(contingency, alternative="greater"),
    )

In [None]:
# Flag hits by direction of change, then attach the COG x category membership
# matrix so each category can be tested against each direction.
_hits_by_direction = pairwise_test_results_filt_with_fdr.assign(
    increased=lambda x: x.hit & (x.median_log2_ratio > 0),
    decreased=lambda x: x.hit & (x.median_log2_ratio < 0),
)
d = _hits_by_direction.join(cog_x_cog_category_matrix)

# One enrichment test per (category, direction) pair, categories outermost.
results = pd.DataFrame(
    [
        (cog_category, direction, *test_enrichment(direction, cog_category, data=d))
        for cog_category, direction in product(
            cog_x_cog_category_matrix.columns, ["increased", "decreased"]
        )
    ],
    columns=[
        "cog_category",
        "direction",
        "num_hit",
        "log2_odds_ratio_pc",
        "fisher_stat",
        "pvalue",
    ],
)
results.sort_values("pvalue")