## Preamble

In [None]:
# Load IPython's autoreload extension (no `%autoreload` mode is set in this cell).
%load_ext autoreload

In [None]:
import os as _os

# Run from the project root so the relative paths used below resolve.
# Raises KeyError if the PROJECT_ROOT environment variable is not set.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Display the resolved working directory.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Return *labels* permuted into the left-to-right leaf order of a dendrogram.

    Parameters
    ----------
    linkage : ndarray
        SciPy linkage matrix, e.g. from ``scipy.cluster.hierarchy.linkage``.
    labels : sequence
        One label per original observation; ``labels[i]`` names leaf ``i``.

    Returns
    -------
    list
        The labels reordered to match the dendrogram's leaves.
    """
    leaf_ids = sp.cluster.hierarchy.leaves_list(linkage)
    label_array = np.array(labels)
    return list(label_array[leaf_ids])


def is_prime(n):
    """Return True if *n* is prime, using trial division.

    Values <= 1 are never prime. Candidate divisors are tried while
    ``divisor * divisor <= n``. This bound uses exact integer arithmetic, so
    it stays correct for very large integers, where a floating-point square
    root can round below the true root.
    """
    if n <= 1:
        return False
    divisor = 2
    while divisor * divisor <= n:
        if n % divisor == 0:
            return False
        divisor += 1
    return True


def iterate_primes_up_to(n, return_index=False):
    """Lazily yield the primes strictly below ``ceil(n)``, in increasing order.

    When *return_index* is true, each item is a ``(rank, prime)`` pair, where
    ``rank`` is the 0-based position of ``prime`` in the sequence. Otherwise
    the bare prime is yielded.
    """
    upper = int(np.ceil(n))
    primes = (candidate for candidate in range(upper) if is_prime(candidate))
    for rank, prime in enumerate(primes):
        yield (rank, prime) if return_index else prime


def maximally_shuffled_order(sorted_order):
    """Return the labels of *sorted_order* in a deterministic, scrambled order.

    Each label's original position ``i`` is reduced modulo every prime yielded
    by ``iterate_primes_up_to(sqrt(len(sorted_order)))``. Sorting the
    positions lexicographically by those residues gives a permutation of
    positions. The label at position ``i`` is given that permutation's
    ``i``-th entry as its rank, and labels are returned sorted by rank.

    Parameters
    ----------
    sorted_order : sequence of hashable labels

    Returns
    -------
    list
        The same labels, reordered.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    for prime in primes_list:
        table[prime] = table.original_order % prime
    # Original positions listed in lexicographic order of their residues.
    shuffled_positions = table.sort_values(primes_list).original_order.values
    table = table.assign(new_order=shuffled_positions)
    return table.sort_values("new_order").index.to_list()

### Set Style

In [None]:
# Global plotting style: seaborn "talk" context, and figures rendered at 50 dpi.
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 50

## Papermill parameters

In [None]:
# This cell is tagged "parameters" for papermill.
# See <https://papermill.readthedocs.io/en/latest/usage-parameterize.html#how-parameters-work> for some gotchas.
# NOTE: *ALL* parameters should be passed to papermill. Values set here are only for prototyping.
# NOTE: values derived from other parameters in this cell (f-strings below) are
# not recomputed when papermill injects overrides in a following cell.

# Run identifiers; the path templates below are built from these.
species_id = "100099"
sfacts_params = "filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0"
spgc_specgene_params = "specgene-ref-t25-p95"
spgc_paramsA = f"{spgc_specgene_params}_ss-all_t-30"
spgc_params = f"spgc_{spgc_paramsA}_thresh-corr200-depth250"
group = "een"
pangenome_params = "gene99_new-v22-agg75"

# Taxonomy, StrainFacts fit, SPGC strain/gene tables, and gene annotations
species_taxonomy_inpath = "ref/gtpro/species_taxonomy_ext.tsv"
sfacts_fit_inpath = (
    f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_params}.world.nc"
)
spgc_meta_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_params}.{pangenome_params}.{spgc_params}.strain_meta.tsv"

ref_gene_copies_inpath = f"data/species/sp-{species_id}/gene75_new.reference_copy_number.nc"
spgc_hits_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_params}.{pangenome_params}.{spgc_params}.strain_gene.tsv"
spgc_depth_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_params}.{pangenome_params}.spgc_{spgc_paramsA}.strain_depth_ratio.tsv"
spgc_corr_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_params}.{pangenome_params}.spgc_{spgc_paramsA}.strain_correlation.tsv"
sample_to_spgc_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_params}.spgc_ss-all.strain_samples.tsv"
uhgg_x_eggnog_inpath = (
    f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.gene_x_eggnog.tsv"
)
uhgg_x_top_eggnog_inpath = f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.gene_x_top_eggnog.tsv"
uhgg_gene_length_inpath = (
    f"ref/midasdb_uhgg_new/pangenomes/{species_id}/cluster_info.txt"
)
gene_annotations_inpath = f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.d/proteins.emapper.annotations"

# Sample Metadata
raw_depth_inpath = (
    f"data/group/{group}/species/sp-{species_id}/r.proc.{pangenome_params}.depth2.nc"
)
all_species_gtpro_depth_inpath = f"data/group/{group}/r.proc.gtpro.species_depth.tsv"
species_depth_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.{pangenome_params}.spgc_{spgc_specgene_params}.species_depth.tsv"
stool_inpath = "meta/een-mgen/stool.tsv"
subject_inpath = "meta/een-mgen/subject.tsv"
microcosm_inpath = "meta/een-mgen/microcosm.tsv"

# Outputs
html_outpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_params}.{pangenome_params}.{spgc_params}.spgc_ref_comparison.html"

## Load Metadata

### Gene Metadata

In [None]:
# Gene -> eggNOG mapping tables for this species' pangenome centroids.
uhgg_x_eggnog = pd.read_table(uhgg_x_eggnog_inpath)

In [None]:
uhgg_x_top_eggnog = pd.read_table(uhgg_x_top_eggnog_inpath)

In [None]:
# Mean centroid_99 length within each centroid_75 cluster (Series indexed by
# centroid_75).
uhgg_gene_length = (
    pd.read_table(uhgg_gene_length_inpath)
    .groupby("centroid_75")
    .centroid_99_length.mean()
)

In [None]:
# Column names for the eggNOG-mapper annotation table. The file is read with
# comment="#", so lines starting with "#" (including its header) are skipped.
eggnog_column_names = "query seed_ortholog evalue score eggNOG_OGs max_annot_lvl COG_category Description Preferred_name GOs EC KEGG_ko KEGG_inpathway KEGG_Module KEGG_Reaction KEGG_rclass BRITE KEGG_TC CAZy BiGG_Reaction PFAMs".split(
    " "
)
_gene_annotations = (
    pd.read_table(
        gene_annotations_inpath,
        comment="#",
        names=eggnog_column_names,
        index_col="query",
    )
    .rename_axis(index="gene_id")
    # Treat "-" entries as missing values.
    .replace({"-": np.nan})
)
# Left-join the annotations onto the gene-length index: every centroid_75 from
# cluster_info is kept, and genes without annotations get NaN.
_gene_annotations = uhgg_gene_length.to_frame().join(_gene_annotations)

_gene_annotations.info()

In [None]:
# One row per (gene, COG category letter): each COG_category string is split
# into single characters, and "-" / missing entries are dropped.
gene_x_cog_category1 = (
    _gene_annotations.COG_category.fillna("-").apply(list).explode()[lambda x: x != "-"]
)
gene_x_cog_category1

In [None]:
# COG id -> category string, read from ref/cog-20.meta.tsv (only the
# cog_category column is kept).
cog_x_category = pd.read_table(
    "ref/cog-20.meta.tsv",
    names=["cog", "cog_category", "description", "short_name", "_4", "_5", "_6"],
    index_col="cog",
).cog_category
cog_x_category

In [None]:
# Long-format gene -> KEGG KO and gene -> EC mappings (one row per entry).
# NOTE: `.str[len('ko:'):]` drops the first three characters unconditionally,
# i.e. it assumes every KO entry carries the "ko:" prefix.
gene_x_ko = _gene_annotations.KEGG_ko.str.split(',').explode().dropna().str[len('ko:'):]
gene_x_ec = _gene_annotations.EC.str.split(',').explode().dropna()

In [None]:
# Long-format gene -> KEGG module mapping.
gene_x_kmodule = _gene_annotations.KEGG_Module.str.split(',').explode().dropna()

In [None]:
# Long-format gene -> KEGG pathway mapping.
gene_x_kpathway = _gene_annotations.KEGG_inpathway.str.split(',').explode().dropna()

In [None]:
# COG ids from eggNOG_OGs entries that start with "COG" (the part before "@"),
# then mapped to category letters through cog_x_category.
gene_x_cog = (
    _gene_annotations.eggNOG_OGs.fillna("")
    .str.split(",")
    .explode()[lambda x: x.str.startswith("COG")]
    .str.split("@")
    .str[0]
)
gene_x_cog.value_counts().head()  # NOTE: not the cell's last expression, so it is not displayed
gene_x_cog_category2 = gene_x_cog.map(cog_x_category).dropna().apply(list).explode()
gene_x_cog_category2

In [None]:
# Union of both gene -> category-letter sources with duplicate (gene, letter)
# pairs removed; the result is a Series of letters indexed by centroid_75.
gene_x_cog_category = (
    pd.concat(
        [
            gene_x_cog_category1,
            gene_x_cog_category2,  # FIXME: Which metadata table do I want?
        ]
    )
    .reset_index()
    .drop_duplicates()
)

gene_x_cog_category.columns = ["centroid_75", "cog_category"]
gene_x_cog_category = gene_x_cog_category.set_index("centroid_75").cog_category
gene_x_cog_category.shape[0]

In [None]:
# Replace COG_category with the merged letters of each gene, sorted and
# concatenated into one string; genes with no category get "".
gene_annotations = _gene_annotations.assign(
    COG_category=gene_x_cog_category.sort_values()
    .reset_index()
    .groupby("centroid_75")
    .apply(lambda x: "".join(x.cog_category.values))
).assign(COG_category=lambda x: x.COG_category.fillna(""))

In [None]:
# Boolean gene x COG-category indicator matrix covering every annotated gene,
# plus a `no_category` column for genes with no category at all.
# Missing (gene, category) pairs are filled with False during the unstack.
# This keeps the frame boolean, rather than producing an object frame that a
# later `.fillna(False)` would have to downcast (deprecated since pandas 2.2).
gene_x_cog_category_matrix = (
    gene_x_cog_category.reset_index()
    .assign(tally=True)
    .set_index(["centroid_75", "cog_category"])
    .tally.unstack("cog_category", fill_value=False)
    .reindex(gene_annotations.index, fill_value=False)
    .assign(no_category=lambda x: x.sum(1) == 0)
)
gene_x_cog_category_matrix

### Sample Metadata

In [None]:
# Stool-sample metadata indexed by mgen_id (the commented-out steps are not applied).
stool_meta = (
    pd.read_table(stool_inpath)
    # .rename(columns={'Seq-Name': 'sample', 'CED/Patient-recoded': 'subject_id', 'sampleDate': 'date', 'Diet (=PreEEN, EEN, PostEEN)': 'sample_type'})
    # .assign(
    #     date=lambda x: pd.to_datetime(x.date),
    #     sample_type=lambda x: x.sample_type.fillna('???')
    # )
    .set_index("mgen_id")
    # .sort_values(['subject_id', 'date', 'sample_type'])
)

# FIXME: Metadata seems to include a swap in the metagenomic data of CF_11 and CF_15.
# rename() applies the whole mapping at once, so the two index labels are swapped.
stool_meta = stool_meta.rename({"CF_11": "CF_15", "CF_15": "CF_11"})
stool_meta

In [None]:
# Microcosm metadata indexed by mgen_id; the inoculum donor column is renamed
# to subject_id so it lines up with stool_meta.
microcosm_meta = (
    pd.read_table(microcosm_inpath)
    .set_index("mgen_id")
    .rename(columns={"inoculum_subject_id": "subject_id"})
)
microcosm_meta

In [None]:
# Combined stool + microcosm sample metadata. Each sample gets a `label` tuple
# used as the bar-chart tick label in later plots:
#   stool:      (mgen_id, collection_date_relative_een_end, sample_type)
#   microcosm:  (mgen_id, inoculum_mgen_id, "invitro")
# Microcosm samples get collection_date_relative_een_end = inf, which places
# them after all stool samples when sorting on that column.
meta = pd.concat(
    [
        stool_meta.assign(
            label=lambda x: x.assign(idx=x.index)[
                ["idx", "collection_date_relative_een_end", "sample_type"]
            ].apply(tuple, axis=1)
        ),
        microcosm_meta.assign(
            collection_date_relative_een_end=np.inf,
            # assign() evaluates keywords in order, so `label` below sees this column.
            sample_type="invitro",
            label=lambda x: x.assign(idx=x.index)[
                ["idx", "inoculum_mgen_id", "sample_type"]
            ].apply(tuple, axis=1),
        ),
    ]
)

In [None]:
# Number of samples per subject x sample type, showing the 20 subjects with the
# most "EEN" samples.
meta[["subject_id", "sample_type"]].value_counts().unstack(fill_value=0).sort_values(
    "EEN", ascending=False
).head(20)

In [None]:
# Subject ids in the order their per-subject panels are drawn below.
# "H" deliberately follows "B", and "I"/"J" are not included.
subject_list = list("ABHCDEFGKLMNOPQRSTU")

### Taxonomy

In [None]:
species_taxonomy = lib.thisproject.data.load_species_taxonomy(species_taxonomy_inpath)
# Manually override the species name recorded for species 102506.
species_taxonomy.loc["102506", "s__"] = "s__Escherichia coli"

species_taxonomy.loc[species_id]  # NOTE: not the cell's last expression, so it is not displayed
# Strip the "s__" prefix and replace spaces, e.g. "s__Genus species" -> "Genus_species".
species_name = species_taxonomy.loc[species_id].s__[len("s__") :].replace(" ", "_")
print(species_name)

### Strain Metadata

In [None]:
# Per-strain SPGC summary table; `.rename(str)` casts the strain index labels to str.
strain_meta = pd.read_table(spgc_meta_inpath, index_col="strain").rename(str)

In [None]:
# Derive plausible bounds on a strain's gene count.
# 1. Fit a location/scale Student-t with df fixed at 2 to the gene counts of
#    strains that have > 90% of species genes.
# 2. Take the 0.1% and 99.9% quantiles of a normal with the same loc/scale.
x = strain_meta[lambda x: x.species_gene_frac > 0.9].num_genes
_df, _loc, _scale = sp.stats.t.fit(x.values, fix_df=2)
_dist0 = sp.stats.t(_df, _loc, _scale)
_dist1 = sp.stats.norm(_loc, _scale)

thresh_max_num_genes = _dist1.ppf(0.999)
thresh_min_num_genes = _dist1.ppf(0.001)


bins = np.linspace(0, x.max() * 1.5, num=50)
xx = np.linspace(0, x.max() * 1.5, num=1000)

# density=True puts the histogram on the same scale as the fitted pdfs.
# With raw counts, the pdf curves would be flattened against zero.
plt.hist(x, bins=bins, alpha=0.2, density=True)

plt.plot(xx, _dist0.pdf(xx), color="k")  # fitted t (solid)
plt.plot(xx, _dist1.pdf(xx), color="k", linestyle="--")  # normal used for the thresholds (dashed)
plt.axvline(thresh_max_num_genes, lw=1, linestyle="--", color="k")
plt.axvline(thresh_min_num_genes, lw=1, linestyle="--", color="k")

In [None]:
thresh_min_num_genes, thresh_max_num_genes

In [None]:
# Gene count affine-transformed so that min_genes = 0, max_genes = 1
# (values outside [0, 1] fall outside the gene-count thresholds).
scaled_gene_count = (strain_meta.num_genes - thresh_min_num_genes) / (
    thresh_max_num_genes - thresh_min_num_genes
)

plt.scatter(
    "sum_depth",
    "species_gene_frac",
    c=scaled_gene_count,
    data=strain_meta,
    # A gamma of 1 makes this a linear colour scale over [-0.2, 1.2].
    norm=mpl.colors.PowerNorm(1 / 1, vmin=-0.2, vmax=1.2),
)
plt.xscale("log")
plt.yscale("logit")
plt.axvline(1, lw=1, color="k", linestyle="--")  # sum_depth cutoff used by strain_list_filt
plt.axhline(0.9, lw=1, color="k", linestyle="--")  # species_gene_frac cutoff used by strain_list_filt
plt.colorbar()

In [None]:
# Strains passing all QC filters: sum_depth > 1, species_gene_frac > 0.9, and
# num_genes within the fitted [min, max] thresholds.
strain_list_filt = idxwhere(
    (strain_meta.sum_depth > 1)
    & (strain_meta.species_gene_frac > 0.9)
    & (strain_meta.num_genes <= thresh_max_num_genes)
    & (strain_meta.num_genes >= thresh_min_num_genes)
)
strain_meta.loc[strain_list_filt]

## Load Analysis

### Species Tracking

In [None]:
# GT-Pro depth as a sample x species matrix (missing pairs = 0), with species
# ids cast to str column labels.
all_species_gtpro_depth = (
    pd.read_table(
        all_species_gtpro_depth_inpath,
        index_col=["sample", "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str)
)

# NOTE(review): this `bins` is not used by the histograms in this cell, which
# pass their own log-spaced bins.
bins = np.linspace(0, 30_000, num=200)

fig, axs = plt.subplots(2, sharex=True)

for (title, x), ax in zip(
    dict(
        total_depth_by_sample=all_species_gtpro_depth.sum(1),
        mean_depth_by_species=all_species_gtpro_depth.mean(0),
    ).items(),
    axs.flatten(),
):
    ax.hist(x, bins=np.logspace(-3, 4, num=100))
    ax.set_title(title)
    ax.set_xscale("log")
fig.tight_layout()

# Per-sample relative abundance: each row divided by its total depth
# (rows with zero total depth become NaN).
species_rabund = all_species_gtpro_depth.divide(all_species_gtpro_depth.sum(1), axis=0)
# Ten species present (> 1e-4) in the largest fraction of samples, with taxonomy.
(species_rabund > 1e-4).mean().sort_values(ascending=False).to_frame().join(
    species_taxonomy[["f__", "g__", "s__"]]
).head(10)

### Strain Tracking

In [None]:
# Sample -> SPGC strain assignment, with strain ids as str.
sample_to_strain = pd.read_table(
    sample_to_spgc_inpath, index_col="sample"
).strain.astype(str)

# NOTE(review): spgc_palette is not referenced elsewhere in the cells shown here.
spgc_palette = lib.plot.construct_ordered_palette(
    sample_to_strain, other=(0.8, 0.8, 0.8, 1.0)
)

# Strain x subject table of sample counts, shown as an annotated clustermap.
d = (
    sample_to_strain.to_frame()
    .assign(subject_id=meta.subject_id)
    .value_counts()
    .unstack("subject_id", fill_value=0)
)
cg = sns.clustermap(d, annot=d, norm=mpl.colors.PowerNorm(1 / 2))
cg.ax_cbar.set_visible(False)

In [None]:
# Load the fitted StrainFacts model. position_ss is used to subset positions
# for plotting (presumably 1000 randomly sampled positions).
world = sf.data.World.load(
    sfacts_fit_inpath,
)
position_ss = world.random_sample(position=1000).position
# world.data['strain'] = world.data.strain.to_series().map(str).to_xarray()

In [None]:
# Genotype similarity ordered palette:
# strains are coloured in the leaf order of a genotype-distance dendrogram.
strain_linkage = world.genotype.linkage(optimal_ordering=True)
strain_order = linkage_order(strain_linkage, world.strain.values)
strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="rainbow",
)

In [None]:
# Plot the metagenotype, community, and genotype of the position-subsampled fit.
w = world.sel(position=position_ss)
_sample_linkage = w.unifrac_linkage(optimal_ordering=True)

# Column colours: each sample's SPGC strain, cast to int for the palette
# lookup. Samples with no SPGC strain (NaN) are mapped to -1.
sf.plot.plot_metagenotype(
    w,
    col_linkage=_sample_linkage,
    col_colors=(
        w.sample.to_series()
        .map(sample_to_strain)
        .map(lambda x: np.nan_to_num(x, nan=-1).astype(int))
        .map(strain_palette)
    ),
    scalex=0.3,
)
sf.plot.plot_community(
    w,
    col_linkage=_sample_linkage,
    col_colors=(
        w.sample.to_series()
        .map(sample_to_strain)
        .map(lambda x: np.nan_to_num(x, nan=-1).astype(int))
        .map(strain_palette)
    ),
    row_linkage=strain_linkage,
    row_colors=w.strain.to_series().map(strain_palette),
    scalex=0.3,
    scaley=0.3,
)
sf.plot.plot_genotype(
    w,
    row_linkage=strain_linkage,
    # row_linkage_func=lambda w: w.genotype.linkage(optimal_ordering=True),
    row_colors=w.strain.to_series().map(strain_palette),
    scaley=0.3,
)

In [None]:
# Per-sample table of metadata, this species' relative abundance, and fitted
# strain fractions (one column per strain), sorted by subject, date, and type.
_species_rabund = species_rabund[species_id]
_frac = world.community.to_pandas()

d0 = (
    meta.assign(
        species_rabund=_species_rabund,
    )
    .join(_frac)
    .sort_values(["subject_id", "collection_date_relative_een_end", "sample_type"])
)

# One panel per subject:
# - stacked bars of strain fractions per sample (ticks are `label` tuples);
# - the species' relative abundance as a black line on a twin y-axis.
fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=8,
    ax_height=6,
)
fig.suptitle(species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    d1[strain_order].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=strain_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Fixed x-range so panels line up; bars beyond position 14 are clipped.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

In [None]:
# Per-subject strain-fraction plots. Only strains in `strain_list_filt` keep
# their colour; every other strain is drawn in the same light grey
# (0.8, 0.8, 0.8, 1.0) passed as `other` to the SPGC palette.
# Strains are greyed out rather than removed from the palette because
# `d1[strain_order]` still plots every strain. Given a dict `color`, pandas'
# bar plot looks up `color[column]` for each plotted column, so a missing
# strain raises KeyError when the palette is a plain dict.
_other_color = (0.8, 0.8, 0.8, 1.0)
_palette = {
    strain: (color if str(strain) in strain_list_filt else _other_color)
    for strain, color in strain_palette.items()
}

fig, axs = lib.plot.subplots_grid(
    ncols=4,
    naxes=len(subject_list),
    ax_width=8,
    ax_height=6,
)
fig.suptitle(species_name)
for subject_id, ax in zip(subject_list, axs.flatten()):
    ax.set_title(subject_id)
    twin_ax = ax.twinx()
    d1 = d0[lambda x: x.subject_id == subject_id].set_index("label")
    d1[strain_order].plot(
        kind="bar",
        width=0.95,
        stacked=True,
        color=_palette,
        ax=ax,
    )
    d1.species_rabund.plot(kind="line", ax=twin_ax, color="k")
    ax.legend_.set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_ylabel("strain fraction")
    ax.set_xlabel("")
    twin_ax.set_ylabel("species relative abundance")
    twin_ax.set_ylim(0)
    lib.plot.rotate_xticklabels(ax)
    # Fixed x-range so panels line up; bars beyond position 14 are clipped.
    ax.set_xlim(-0.5, 14)
fig.tight_layout()

In [None]:
# Per-sample SPGC species depth. Columns are named explicitly, so this assumes
# the file has no header row — TODO confirm.
species_depth = pd.read_table(
    species_depth_inpath, names=["sample", "depth"], index_col=["sample"]
).depth
species_depth

In [None]:
# Sum metagenotype counts over all samples assigned to the same SPGC strain.
# The grouped dimension is renamed from "strain" to "sample" so the result can
# be wrapped as a Metagenotype with one pseudo-sample per strain.
pure_strain_agg_mgtp = sf.Metagenotype(
    world.metagenotype.data.sel(sample=sample_to_strain.index)
    .groupby(sample_to_strain.to_xarray())
    .sum()
    .rename(strain="sample")
)

sf.plot.plot_metagenotype(pure_strain_agg_mgtp.sel(position=position_ss))
sf.plot.plot_dominance(pure_strain_agg_mgtp.sel(position=position_ss))

In [None]:
# Compare genotypes estimated from the pooled SPGC samples ("mgen") with the
# StrainFacts-fitted genotypes ("fit"), concatenated along the strain dim.
spgc_est_geno = pure_strain_agg_mgtp.to_estimated_genotype(pseudo=0)
g = sf.data.Genotype.concat(dict(mgen=spgc_est_geno, fit=world.genotype), dim="strain")
g_pdist = g.pdist()  # NOTE(review): not used in the cells shown here
g_linkage = g.linkage()
sf.plot.plot_genotype(
    g.sel(position=position_ss), transpose=True, col_linkage=g_linkage
)

### Strain Genes

In [None]:
# Raw per-gene, per-sample depth, and the same depth divided by each sample's
# species depth.
gene_x_sample_depth = xr.load_dataarray(raw_depth_inpath)
gene_x_sample_depth_ratio = gene_x_sample_depth / species_depth.to_xarray().sel(
    sample=gene_x_sample_depth.sample
)

In [None]:
# Gene x strain matrices of SPGC depth ratio and correlation (missing pairs = 0).
strain_gene_depth = pd.read_table(spgc_depth_inpath, index_col=['gene_id', 'strain'], dtype={'strain': str}).depth.unstack(fill_value=0)
strain_gene_corr = pd.read_table(spgc_corr_inpath, index_col=['gene_id', 'strain'], dtype={'strain': str}).correlation.unstack(fill_value=0)

In [None]:
# Gene x strain SPGC hit table (presumably 0/1 presence calls — confirm).
strain_genes = pd.read_table(spgc_hits_inpath, index_col=["gene_id"])
strain_genes

In [None]:
# Gene presence (copy number >= 1) across reference genomes as 0.0/1.0.
# Genes are clustered by average linkage on cosine distance, cutting at 0.05,
# so each cluster groups genes with near-identical presence patterns.
ref_genes = (xr.load_dataarray(ref_gene_copies_inpath).to_pandas().T >= 1).astype(float)
gene_corr_cluster = pd.Series(sp.cluster.hierarchy.fcluster(sp.cluster.hierarchy.linkage(ref_genes, method='average', metric='cosine'), criterion='distance', t=0.05), index=ref_genes.index)
gene_corr_cluster_size = gene_corr_cluster.value_counts()
multi_gene_clusters = idxwhere(gene_corr_cluster_size > 1)

In [None]:
# Mean reference-genome presence of each multi-gene cluster.
sns.clustermap(ref_genes.groupby(gene_corr_cluster[lambda x: x.isin(multi_gene_clusters)]).mean())

In [None]:
_strain_list = strain_list_filt
# Variable genes among the QC-passing strains: hit sum > 1 and mean < 1. For
# 0/1 hits this means present in at least two strains but absent from one.
_gene_list = idxwhere(strain_genes[strain_list_filt].apply(lambda x: (x.sum() > 1) & (x.mean() < 1.0), axis=1))
x = strain_genes.reindex(index=_gene_list, columns=_strain_list, fill_value=0)
# Strains ordered by similarity of genotypes estimated from pooled SPGC samples.
_col_linkage = (
    pure_strain_agg_mgtp.to_estimated_genotype()
    .sel(strain=_strain_list)
    .linkage("strain")
)

_row_linkage = sp.cluster.hierarchy.linkage(x, method='average', metric='cosine')

# _row_colors = (
#     pd.DataFrame(index=_gene_list)
#     .assign(
#         species_gene=lambda x: x.index.isin(species_genes),
#     )
#     # .join(gene_x_cog_category_matrix)
#     .apply(lambda x: x.map({True: "black", False: "grey"}))
# )
_col_colors = pd.Series([int(x) for x in _strain_list], index=_strain_list).map(strain_palette)

sns.clustermap(
    x,
    col_linkage=_col_linkage,
    row_linkage=_row_linkage,
    # row_colors=_row_colors,
    col_colors=_col_colors
)

In [None]:
# Same genes/strains and orderings, coloured by SPGC depth ratio.
x2 = strain_gene_depth.reindex(index=_gene_list, columns=_strain_list, fill_value=0)

sns.clustermap(
    x2,
    col_linkage=_col_linkage,
    row_linkage=_row_linkage,
    # row_colors=_row_colors,
    col_colors=_col_colors,
    norm=mpl.colors.PowerNorm(1/3, vmin=0, vmax=10),
)

In [None]:
# Same genes/strains and orderings, coloured by SPGC correlation.
x2 = strain_gene_corr.reindex(index=_gene_list, columns=_strain_list, fill_value=0)

sns.clustermap(
    x2,
    col_linkage=_col_linkage,
    row_linkage=_row_linkage,
    # row_colors=_row_colors,
    col_colors=_col_colors
)

In [None]:
# For each multi-gene reference cluster and strain, the mean SPGC hit value
# over the cluster's genes (genes missing from strain_genes count as 0).
# A cluster counts as present in a strain when this fraction is >= 0.5.
strain_gene_cluster_gene_fraction = strain_genes.reindex(index=gene_corr_cluster.index, fill_value=0).groupby(gene_corr_cluster[lambda x: x.isin(multi_gene_clusters)]).mean()
strain_gene_cluster_hits = strain_gene_cluster_gene_fraction >= 0.5

In [None]:
_strain_list = strain_list_filt
x = strain_gene_cluster_gene_fraction[_strain_list]
_col_linkage = (
    pure_strain_agg_mgtp.to_estimated_genotype()
    .sel(strain=_strain_list)
    .linkage("strain")
)

_row_linkage = sp.cluster.hierarchy.linkage(x, method='average')
_col_colors = pd.Series([int(x) for x in _strain_list], index=_strain_list).map(strain_palette)

sns.clustermap(
    x,
    col_linkage=_col_linkage,
    row_linkage=_row_linkage,
    col_colors=_col_colors
)

In [None]:
# Distribution of the number of QC-passing strains in which each cluster is present.
plt.hist(strain_gene_cluster_hits[_strain_list].sum(1), bins=np.arange(10))

In [None]:
# Cluster size (x) vs. fraction of QC-passing strains carrying the cluster (y).
# `x` and `y` are reused by a later cell.
x, y = lib.pandas_util.align_indexes(gene_corr_cluster_size[multi_gene_clusters], strain_gene_cluster_hits[strain_list_filt].mean(1))
plt.hist2d(x, y, norm=mpl.colors.PowerNorm(1/3), bins=(np.logspace(0, 3), np.linspace(0, 1, num=11)))
plt.xscale('log')
None  # suppress display of the last expression's value

In [None]:
# Number of genes per (cluster, COG category letter).
gene_cluster_cog_category_count = gene_x_cog_category.groupby(gene_corr_cluster).value_counts()
gene_cluster_cog_category_count.sort_values(ascending=False).head(10)

In [None]:
# Number of genes per (cluster, KEGG pathway).
gene_cluster_kpathway_count = gene_x_kpathway.groupby(gene_corr_cluster).value_counts()
# NOTE(review): the commented line below references gene_cluster_kmodule_count, which is not defined here.
# most_kmodule_like_cluster_list= gene_cluster_kmodule_count.max().sort_values(ascending=False).head(20).index
gene_cluster_kpathway_count.sort_values(ascending=False).head(100).tail(50)

In [None]:
# Clusters with > 10 genes found in > 25% of QC-passing strains (x/y from the
# hist2d cell). Each is shown with up to 5 COG categories that occur in more
# than one of its genes.
pd.DataFrame(dict(x=x, y=y))[lambda x: (x.x > 10) & (x.y > 0.25)].rename(int).assign(cluster_top=gene_cluster_cog_category_count.unstack(fill_value=0).apply(lambda x: idxwhere(x.sort_values(ascending=False) > 1)[:5], axis=1))

In [None]:
# Annotations of the genes in cluster 2812.
# NOTE(review): 2812 is a hard-coded fcluster label, which depends on the
# inputs of this run. `.loc` with a boolean Series also raises IndexingError
# if gene_annotations has genes missing from gene_corr_cluster's index — confirm.
gene_annotations.loc[gene_corr_cluster == 2812]

### Bile Acid Genes

In [None]:
# Annotations of genes assigned to COG1028.
gene_annotations.loc[idxwhere(gene_x_cog == 'COG1028')]

In [None]:
# KEGG orthology (KO) ids used to select bile-acid-related genes; they are
# matched against gene_x_ko in later cells.
bile_acid_ko_list = """
K00076 K01442 K07007
K15868 K15869 K15870 K15871 K15872 K15873 K15874
K22604 K22605 K22606 K22607
K23231
""".split()

In [None]:
# EC numbers used to select bile-acid-related genes; they are matched against
# gene_x_ec in a later cell.
bile_acid_ec_list = """
1.1.1.159 1.1.1.176 1.1.1.201 1.1.1.238
1.1.1.391 1.1.1.392 1.1.1.393 1.1.1.395 1.1.1.52
1.3.1.114 1.3.1.115 1.3.1.116
2.8.3.25
3.5.1.24 3.5.1.74
4.2.1.106
6.2.1.7
""".split()

In [None]:

# Genes annotated with any of the bile-acid KOs.
idxwhere(gene_x_ko.isin(bile_acid_ko_list))

In [None]:
# Genes whose KEGG pathway annotation mentions ko00121 (substring match).
idxwhere(gene_annotations.KEGG_inpathway.fillna('').str.contains('ko00121'))

In [None]:
# Genes annotated with any of the bile-acid EC numbers.
idxwhere(gene_x_ec.isin(bile_acid_ec_list))

In [None]:
# SPGC hits of bile-acid-KO genes across QC-passing strains, with strains
# ordered by pooled-sample genotype similarity.
# NOTE(review): if no gene matches bile_acid_ko_list, `x` is empty and
# clustermap will likely fail — confirm for each species.
_strain_list = strain_list_filt
_gene_list = idxwhere(gene_x_ko.isin(bile_acid_ko_list))
x = strain_genes.reindex(index=_gene_list, columns=_strain_list, fill_value=0)
_col_linkage = (
    pure_strain_agg_mgtp.to_estimated_genotype()
    .sel(strain=_strain_list)
    .linkage("strain")
)

# _row_colors = (
#     pd.DataFrame(index=_gene_list)
#     .assign(
#         species_gene=lambda x: x.index.isin(species_genes),
#     )
#     # .join(gene_x_cog_category_matrix)
#     .apply(lambda x: x.map({True: "black", False: "grey"}))
# )
_col_colors = pd.Series([int(x) for x in _strain_list], index=_strain_list).map(strain_palette)

sns.clustermap(
    x,
    col_linkage=_col_linkage,
    # row_colors=_row_colors,
    col_colors=_col_colors
)

#### Confirm Depth Estimates (FIXME: Drop)

In [None]:
# GT-Pro depth of the focal species in each sample.
gtpro_species_depth = all_species_gtpro_depth[species_id]
gtpro_species_depth

In [None]:
# Reference-derived species-gene list for this species. The specgene part of
# the path comes from the `spgc_specgene_params` papermill parameter, so it
# follows the same settings as `species_depth_inpath` rather than a hard-coded
# default.
species_genes = lib.pandas_util.read_list(
    f"data/species/sp-{species_id}/midasuhgg.pangenome.gene75_new.spgc_{spgc_specgene_params}.species_gene.list"
)
# species_genes_denovo = lib.pandas_util.read_list(
#     f"data/group/een/species/sp-{species_id}/r.proc.gene99_new-v22-agg75.spgc_specgene-denovo2-t30-n500.species_gene.list"
# )

In [None]:
# Per-sample 15%-trimmed mean depth over the species genes. This assumes
# gene_x_sample_depth converts to a sample x gene frame (so `.T.apply`
# reduces over genes per sample) — TODO confirm dimension order.
_species_depth = (
    gene_x_sample_depth.sel(gene_id=species_genes)
    .to_pandas()
    .T.apply(sp.stats.trim_mean, proportiontocut=0.15)
)
# _species_depth2 = (
#     gene_x_sample_depth.sel(gene_id=species_genes_denovo)
#     .to_pandas()
#     .T.apply(sp.stats.trim_mean, proportiontocut=0.15)
# )

In [None]:
# GT-Pro vs. SPGC species depth per sample (missing samples filled with 0),
# with a y = x reference line.
x, y = lib.pandas_util.align_indexes(
    gtpro_species_depth, species_depth, how="outer", fill_value=0
)

plt.scatter(x, y)
plt.plot([1e-2, 2e2], [1e-2, 2e2])
plt.yscale("symlog", linthresh=1e-2)
plt.xscale("symlog", linthresh=1e-2)

In [None]:
# SPGC / GT-Pro species-depth ratio for samples with SPGC depth > 1.
(species_depth[lambda x: x > 1] / gtpro_species_depth).dropna().sort_values()

In [None]:
# Median (over samples with species depth > 0.5) of each gene's depth
# relative to the species depth. All genes are plotted, with species genes
# overlaid. Note: this rebinds _species_depth to the aligned, filtered series.
_gene_depth, _species_depth = lib.pandas_util.align_indexes(
    gene_x_sample_depth.to_pandas(), _species_depth[lambda x: x > 0.5]
)
_gene_depth_ratio = _gene_depth.divide(_species_depth, axis=0).median()
bins = np.logspace(-3, 1)
plt.hist(_gene_depth_ratio, bins=bins)
# plt.hist(_gene_depth_ratio[species_genes_denovo], bins=bins, alpha=0.5)
plt.hist(_gene_depth_ratio[species_genes], bins=bins, alpha=0.5)
plt.xscale("log")
plt.yscale("log")