## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
# Run the notebook from the project root so the relative paths used by later
# cells resolve. PROJECT_ROOT must be set; a missing variable raises KeyError.
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
# Last expression of the cell: show the resolved working directory.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import mpltern
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import pdist, squareform
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import dmatrix, load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
# Use seaborn's "paper" plotting context and render every figure at 300 dpi.
sns.set_context("paper")
plt.rcParams["figure.dpi"] = 300

In [None]:
# Matplotlib color name for each genome type label.
# NOTE(review): not referenced in the cells visible here — confirm it is still used.
genome_type_palette = {
    "SPGC": "tab:green",
    "MAG": "tab:orange",
    "Isolate": "tab:blue",
    "Ref": "black",
}

## Data Setup

### Metadata

In [None]:
# Unique species IDs (as strings) listed under the "hmp2" species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda groups: groups["species_group_id"] == "hmp2", "species_id"]
    .astype(str)
    .unique()
)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a Series keyed by rank prefix.

    The fields are labelled, in order, d__, p__, c__, o__, f__, g__, s__, so
    the string must contain exactly seven ';'-separated fields.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    return pd.Series(taxonomy_string.split(";"), index=rank_prefixes)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Keep the rows where a genome is its own species representative, key them by
# species ID, and expand each lineage string into one column per rank.
species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)
    .loc[lambda genomes: genomes["Genome"] == genomes["Species_rep"]]
    .assign(
        # Species ID is "1" followed by the third "-"-separated accession field.
        species_id=lambda genomes: "1"
        + genomes["MGnify_accession"].str.split("-").str[2]
    )
    .set_index("species_id")["Lineage"]
    .apply(parse_taxonomy_string)
)
species_taxonomy

### Strain Statistics

In [None]:
def classify_genome(x):
    """Return the class label of one genome record.

    Parameters
    ----------
    x : object with attributes ``genome_type``, ``passes_filter`` and (read
        only for SPGC genomes that fail the filter) ``passes_geno_positions``,
        e.g. a row passed by ``DataFrame.apply(..., axis=1)``.

    Returns
    -------
    str
        "isolate" / "isolate_fails_qc" for Isolate genomes,
        "mag" / "mag_fails_qc" for MAG genomes, and for SPGC genomes
        "spgc" (passes the filter), "sfacts_only" (fails the filter but
        passes the genotyped-positions check) or "sfacts_fails_qc".

    Raises
    ------
    ValueError
        If ``genome_type`` is not one of "Isolate", "MAG" or "SPGC".
    """
    genome_type = x.genome_type
    # Use real boolean logic. The bitwise ``~`` on a Python bool is an integer
    # operation (deprecated since Python 3.12) and only produced the right
    # branch by accident of bit arithmetic.
    passes_filter = bool(x.passes_filter)

    if genome_type == "Isolate":
        return "isolate" if passes_filter else "isolate_fails_qc"
    if genome_type == "MAG":
        return "mag" if passes_filter else "mag_fails_qc"
    if genome_type == "SPGC":
        if passes_filter:
            return "spgc"
        return "sfacts_only" if bool(x.passes_geno_positions) else "sfacts_fails_qc"
    raise ValueError("Genome did not match classification criteria:", x)

In [None]:
# Load the per-species strain metadata table (SPGC and reference genomes) for
# every species, recording the species whose table is absent on disk.
filt_stats = []
missing_species = []

_species_list = species_list
# _species_list = ["100003"]

for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    # Tag every row with its species and the file it was read from.
    data = pd.read_table(inpath).assign(species=species, inpath=inpath)
    filt_stats.append(data)
# Combine all species into one table indexed by "<species>_<genome_id>" and
# label every genome with its classify_genome() class.
# Note: `filt_stats` is rebound here from a list of frames to a DataFrame.
filt_stats = (
    pd.concat(filt_stats)
    .assign(
        genome_class=lambda x: x.apply(classify_genome, axis=1),
        species_strain=lambda x: x.species + "_" + x.genome_id,
    )
    .set_index("species_strain")
)


print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

In [None]:
# Define different subsets of the species:
# Each `species_listN` holds species IDs and the matching `spgc_strain_listN`
# holds the filt_stats index labels ("<species>_<genome_id>") of the SPGC
# genomes of those species that satisfy the same row-level condition.

# All species:
# species_list
spgc_strain_list = filt_stats[lambda x: x.genome_type.isin(["SPGC"])].index.values

# All species with enough positions
# (any genome type counts here, not only SPGC genomes)
species_list0 = filt_stats[lambda x: x.passes_geno_positions].species.unique()
spgc_strain_list0 = filt_stats[
    lambda x: x.passes_geno_positions
    & x.genome_type.isin(["SPGC"])
    & x.species.isin(species_list0)
].index.values

# All species with sf strains
species_list1 = filt_stats[
    lambda x: x.passes_geno_positions & x.genome_type.isin(["SPGC"])
].species.unique()
spgc_strain_list1 = filt_stats[
    lambda x: x.passes_geno_positions
    & x.genome_type.isin(["SPGC"])
    & x.species.isin(species_list1)
].index.values

# All species with sf strains to talk about distributions (>=10)
species_list1b = idxwhere(
    filt_stats[
        lambda x: x.passes_geno_positions & x.genome_type.isin(["SPGC"])
    ].species.value_counts()
    >= 10
)
spgc_strain_list1b = filt_stats[
    lambda x: x.passes_geno_positions
    & x.genome_type.isin(["SPGC"])
    & x.species.isin(species_list1b)
].index.values

# All species with spgc strains
species_list2 = filt_stats[
    lambda x: x.passes_filter & x.genome_type.isin(["SPGC"])
].species.unique()
spgc_strain_list2 = filt_stats[
    lambda x: x.passes_filter
    & x.genome_type.isin(["SPGC"])
    & x.species.isin(species_list2)
].index.values

# All species with enough spgc strains for pangenome analysis (>=10)
species_list3 = idxwhere(
    filt_stats[
        lambda x: x.passes_filter & x.genome_type.isin(["SPGC"])
    ].species.value_counts()
    >= 10
)
spgc_strain_list3 = filt_stats[
    lambda x: x.passes_filter
    & x.genome_type.isin(["SPGC"])
    & x.species.isin(species_list3)
].index.values

# Species with large numbers of strains (>=20)
species_list4 = idxwhere(
    filt_stats[
        lambda x: x.passes_filter & x.genome_type.isin(["SPGC"])
    ].species.value_counts()
    >= 20
)
spgc_strain_list4 = filt_stats[
    lambda x: x.passes_filter
    & x.genome_type.isin(["SPGC"])
    & x.species.isin(species_list4)
].index.values

# Report every subset: its name, number of species, number of strains and the
# phylum counts of its species.
_species_list_map = {
    "All considered species": (species_list, spgc_strain_list),
    "0: Species with enough genotyped positions": (species_list0, spgc_strain_list0),
    "1: With sfacts strains": (species_list1, spgc_strain_list1),
    "1b: With (>=10) sfacts strains": (species_list1b, spgc_strain_list1b),
    "2: With SPGC inferences": (species_list2, spgc_strain_list2),
    "3: With >=10 inferences": (species_list3, spgc_strain_list3),
    "4: With >=20 inferences": (species_list4, spgc_strain_list4),
}
# Note: this loop rebinds the notebook-level `_species_list` and `_strain_list`.
for _species_list_name, (_species_list, _strain_list) in _species_list_map.items():
    print(_species_list_name, len(_species_list), len(_strain_list))
    print(species_taxonomy.loc[_species_list].p__.value_counts())
    print()

In [None]:
# Quartiles of the number of strains per species among spgc_strain_list3.
filt_stats.loc[spgc_strain_list3].species.value_counts().quantile([0.25, 0.5, 0.75])

## Phylum palette

In [None]:
# Number of species per phylum among species_list1.
species_taxonomy.loc[species_list1].p__.value_counts()

In [None]:
# Phyla in the order used for the palette and for the figure legends below.
# Commented-out entries are phyla deliberately left out of the palette.
phylum_order = [
    "p__Euryarchaeota",
    "p__Thermoplasmatota",
    "p__Firmicutes",
    "p__Firmicutes_A",
    "p__Firmicutes_C",
    # "p__Firmicutes_B", # None in species_list1
    # "p__Firmicutes_G", # B/G/I not sure how related to C or A
    # "p__Firmicutes_I", #
    # "p__Cyanobacteria", # None in species_list1
    "p__Actinobacteriota",
    "p__Synergistota",
    "p__Fusobacteriota",
    "p__Campylobacterota",
    "p__Proteobacteria",
    "p__Desulfobacterota_A",
    "p__Bacteroidota",
    "p__Verrucomicrobiota",
    # "dummy0", # 18
    # "dummy1", # 19
    # "dummy2", # 20
]

# Mapping from phylum name to color, built by the project helper.
# NOTE(review): the exact meaning of `desaturate_levels` is defined in lib.plot — confirm there.
phylum_palette = lib.plot.construct_ordered_palette(
    phylum_order,
    cm="rainbow",
    desaturate_levels=[1.0, 0.5],
)

# Preview: print each color and draw a legend-only figure (empty scatters).
for p__ in phylum_order:
    print(p__, phylum_palette[p__])
    plt.scatter([], [], color=phylum_palette[p__], label=p__)
plt.legend(ncols=4)
lib.plot.hide_axes_and_spines()

# assert len(set(phylum_palette.values())) == len((phylum_palette.values()))

## Prevalences

In [None]:
# Gather the per-gene prevalence table of the inferred (SPGC) strains for each
# species in species_list3, recording species whose file is absent.
spgc_gene_prevalence = []
missing_species = []

_species_list = species_list3

for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.uhgg-strain_gene.prevalence.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    # Column names are supplied explicitly; keep only the prevalence column,
    # indexed by gene_id.
    data = pd.read_table(
        inpath, names=["gene_id", "prevalence"], index_col="gene_id"
    ).prevalence
    spgc_gene_prevalence.append(data)
# Rebind from a list of per-species Series to one Series sorted by gene_id.
spgc_gene_prevalence = pd.concat(spgc_gene_prevalence).sort_index()

print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

In [None]:
# Gather the per-gene reference prevalence table for each species in
# species_list3, recording species whose file is absent.
ref_gene_prevalence = []
missing_species = []

_species_list = species_list3

for species in tqdm(_species_list):
    inpath = f"data/species/sp-{species}/midasdb.gene75_v20.uhgg-strain_gene.ref_prevalence.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    # Column names are supplied explicitly; keep only the prevalence column,
    # indexed by gene_id.
    data = pd.read_table(
        inpath, names=["gene_id", "prevalence"], index_col="gene_id"
    ).prevalence
    ref_gene_prevalence.append(data)
# Rebind from a list of per-species Series to one Series sorted by gene_id.
ref_gene_prevalence = pd.concat(ref_gene_prevalence).sort_index()

print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

### Figure 4A

In [None]:
# Figure 4A: 2-D histogram of reference vs. SPGC gene prevalence.

# Genes whose prevalence never exceeds this value in either set are dropped.
exclude_genes_never_greater_than = 0.01
# Genes present in only one of the two tables get prevalence 0 in the other.
d = pd.DataFrame(dict(ref=ref_gene_prevalence, spgc=spgc_gene_prevalence)).fillna(0)[
    lambda x: x.max(axis=1) > exclude_genes_never_greater_than
]

fig, ax = plt.subplots(figsize=(3.5, 3))
bins = np.linspace(0, 1, num=51)
# hist2d returns (counts, xedges, yedges, image); only the image is needed
# for the colorbar.
*_, art = ax.hist2d(
    "ref",
    "spgc",
    data=d,
    bins=bins,
    norm=mpl.colors.SymLogNorm(1, vmin=1, vmax=1e4),
    cmap="Grays",
)
fig.colorbar(art, ax=ax, label="Genes (count)", shrink=0.94)
# Identity line.
ax.plot([0, 1], [0, 1], lw=2, linestyle=":", color="tab:red")


ax.set_aspect(1)
ax.set_xlabel("Reference Prevalence")
ax.set_ylabel("SPGC Prevalence")
ax.set_xticks([0, 0.15, 0.9, 1.0])
ax.set_yticks([0, 0.15, 0.9, 1.0])
ax.xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))

# Mark the prevalence-class boundaries (15% and 90%) on the SPGC axis.
ax.annotate("core", xy=(0.5, 0.901), ha="center", va="bottom", color="tab:red")
ax.axhline(0.9, lw=1, linestyle="--", color="tab:red")
ax.annotate("shell", xy=(0.5, 0.151), ha="center", va="bottom", color="tab:red")
ax.axhline(0.15, lw=1, linestyle="--", color="tab:red")
ax.annotate("cloud", xy=(0.5, 0.0), ha="center", va="bottom", color="tab:red")
lib.plot.rotate_xticklabels(ax=ax, rotation=90, ha="center")
# lib.plot.rotate_yticklabels(ax=ax, rotation=45, va='top')

# Print the correlation: a bare expression in the middle of a cell is
# evaluated and then silently discarded, so the statistic was never shown.
print(sp.stats.pearsonr(d.ref, d.spgc))

fig.savefig("fig/fig4a_gene_prevalence.pdf", bbox_inches="tight")

## Genome Fractions

### Distribution of genome fractions in inferred strains

In [None]:
# Gather the per-strain prevalence-class table of each species in
# species_list2, restricted to the SPGC strains that pass the filter.

spgc_prevalence_class_counts = []
missing_species = []

_species_list = species_list2

for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.uhgg-strain_gene.prevalence_class_fraction.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue

    # Genome IDs of this species' SPGC strains that pass the filter.
    strain_list = filt_stats[
        lambda x: (x.species == species) & (x.passes_filter) & (x.genome_type == "SPGC")
    ].genome_id.unique()
    data = (
        pd.read_table(inpath, index_col="strain")
        # Strain labels are converted to str before selecting by genome ID.
        .rename(index=str)
        # Label-based selection: a strain absent from the file raises KeyError.
        .loc[strain_list]
        .assign(species=species)
        .reset_index()
        .set_index(["species", "strain"])
    )
    spgc_prevalence_class_counts.append(data)
# Rebind from a list of frames to one frame indexed by (species, strain).
spgc_prevalence_class_counts = pd.concat(spgc_prevalence_class_counts).sort_index()

print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

In [None]:
# Per-species median of each prevalence-class column across strains, with
# every species row then rescaled so its values sum to 1.
median_prevalence_class_fraction = (
    spgc_prevalence_class_counts.groupby("species")
    .median()
    .apply(lambda x: x / x.sum(), axis=1)
)
median_prevalence_class_fraction

### Figure 4B

In [None]:
# Figure 4B: core vs. shell genome fraction per species (species_list3),
# colored by phylum, for the inferred (SPGC) strains.
d0 = median_prevalence_class_fraction.loc[species_list3].assign(
    p__=lambda x: x.index.to_series().map(species_taxonomy.p__),
    # Number of strains per species; aligned to d0 on the species index.
    num_genomes=spgc_prevalence_class_counts.reset_index()["species"].value_counts(),
)

fig, ax = plt.subplots(figsize=(3, 3))
# The legend is drawn on a separate, otherwise empty figure.
_, ax_legend = plt.subplots()

for p__ in phylum_order:
    d1 = d0[lambda x: x.p__ == p__]
    ax.scatter(
        "core",
        "shell",
        data=d1,
        color=phylum_palette[p__],
        facecolor="none",
        s=15,
        label="__nolegend__",
        # marker="o",
        # lw=2,
        # facecolor="none",
        # alpha=0.85,
    )
    # Empty scatter: contributes only a legend entry for this phylum.
    ax_legend.scatter(
        [], [], color=phylum_palette[p__], facecolor="none", label=p__, s=50, lw=3
    )
ax_legend.legend(bbox_to_anchor=(1, 0.5))
lib.plot.hide_axes_and_spines(ax_legend)


# Diagonal guides where core + shell = 1 - cloud_frac, i.e. lines of constant
# cloud fraction; `shift` offsets each label along its line.
for cloud_frac, shift in [(0.0, 0.06), (0.1, 0.04), (0.2, 0.04), (0.3, 0.04)]:
    # Only the last guide carries the "Cloud" prefix in its label.
    if cloud_frac == 0.3:
        annot = f"Cloud {cloud_frac:.0%}"
    else:
        annot = f"{cloud_frac:.0%}"
    ax.plot(
        [0, 1 - cloud_frac],
        [1 - cloud_frac, 0],
        lw=1,
        linestyle=":",
        color="k",
    )
    ax.annotate(
        annot,
        xy=(1 - cloud_frac - shift, -0.04 + shift),
        rotation=-45,
        va="bottom",
        ha="right",
        fontsize=10,
    )

ax.set_xlim(0.35, 0.95)
ax.set_ylim(0.0, 0.6)
ax.xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.set_xlabel("Core")
ax.set_ylabel("Shell")
ax.set_aspect(1)

fig.savefig("fig/fig4b_genome_fraction.pdf", bbox_inches="tight")

In [None]:
# Quartiles, across species_list3, of each prevalence-class fraction.
median_prevalence_class_fraction.loc[species_list3].quantile([0.25, 0.5, 0.75])

### Distribution of genome fractions in reference strains

In [None]:
# Gather the per-strain prevalence-class table of each species in
# species_list2, restricted to the reference genomes (Isolate or MAG) that
# pass the filter.

ref_prevalence_class_counts = []
missing_species = []

_species_list = species_list2

for species in tqdm(_species_list):
    inpath = f"data/species/sp-{species}/midasdb.gene75_v20.uhgg-strain_gene.prevalence_class_fraction.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue

    # Genome IDs of this species' Isolate/MAG genomes that pass the filter.
    strain_list = filt_stats[
        lambda x: (x.species == species)
        & (x.passes_filter)
        & (x.genome_type.isin(["Isolate", "MAG"]))
    ].genome_id.unique()
    data = (
        pd.read_table(inpath, index_col="strain")
        # Strain labels are converted to str before selecting by genome ID.
        .rename(index=str)
        # Label-based selection: a strain absent from the file raises KeyError.
        .loc[strain_list]
        .assign(species=species)
        .reset_index()
        .set_index(["species", "strain"])
    )
    ref_prevalence_class_counts.append(data)
# Rebind from a list of frames to one frame indexed by (species, strain).
ref_prevalence_class_counts = pd.concat(ref_prevalence_class_counts).sort_index()

print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

In [None]:
# Per-species median of each prevalence-class column across reference
# genomes, with every species row then rescaled so its values sum to 1.
ref_median_prevalence_class_fraction = (
    ref_prevalence_class_counts.groupby("species")
    .median()
    .apply(lambda x: x / x.sum(), axis=1)
)
ref_median_prevalence_class_fraction

In [None]:
# Figure S2: core vs. shell genome fraction per species (species_list3),
# colored by phylum, for the reference genomes.
d0 = ref_median_prevalence_class_fraction.loc[species_list3].assign(
    p__=lambda x: x.index.to_series().map(species_taxonomy.p__),
    # Number of reference genomes per species; aligned on the species index.
    num_genomes=ref_prevalence_class_counts.reset_index()["species"].value_counts(),
)

fig, ax = plt.subplots(figsize=(3, 3))
# The legend is drawn on a separate, otherwise empty figure.
_, ax_legend = plt.subplots()

for p__ in phylum_order:
    d1 = d0[lambda x: x.p__ == p__]
    ax.scatter(
        "core",
        "shell",
        data=d1,
        color=phylum_palette[p__],
        facecolor="none",
        s=15,
        label="__nolegend__",
    )
    # Empty scatter: contributes only a legend entry for this phylum.
    ax_legend.scatter(
        [], [], color=phylum_palette[p__], facecolor="none", label=p__, s=50, lw=3
    )
ax_legend.legend(bbox_to_anchor=(1, 0.5))
lib.plot.hide_axes_and_spines(ax_legend)


# Diagonal guides where core + shell = 1 - cloud_frac, i.e. lines of constant
# cloud fraction; `shift` offsets each label along its line.
for cloud_frac, shift in [(0.0, 0.06), (0.1, 0.04), (0.2, 0.04), (0.3, 0.04)]:
    # Only the last guide carries the "Cloud" prefix in its label.
    if cloud_frac == 0.3:
        annot = f"Cloud {cloud_frac:.0%}"
    else:
        annot = f"{cloud_frac:.0%}"
    ax.plot(
        [0, 1 - cloud_frac],
        [1 - cloud_frac, 0],
        lw=1,
        linestyle=":",
        color="k",
    )
    ax.annotate(
        annot,
        xy=(1 - cloud_frac - shift, -0.04 + shift),
        rotation=-45,
        va="bottom",
        ha="right",
        fontsize=10,
    )

ax.set_xlim(0.35, 0.95)
ax.set_ylim(0.0, 0.6)
ax.xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.set_xlabel("Core")
ax.set_ylabel("Shell")
ax.set_aspect(1)

# Use a tight bounding box like every other saved figure in this notebook so
# the axis labels are not clipped.
fig.savefig("fig/fig_s2.pdf", bbox_inches="tight")

In [None]:
# Quartiles, across species_list3, of each reference prevalence-class fraction.
ref_median_prevalence_class_fraction.loc[species_list3].quantile([0.25, 0.5, 0.75])

## COG Category Enrichment

In [None]:
# Description of each COG category, indexed by category code. Column names
# are supplied explicitly; only the description column is kept.
cog_category_description = pd.read_table(
    "ref/cog-20.categories.tsv",
    names=["cog_category", "color", "description"],
    index_col="cog_category",
).description
cog_category_description.sort_index()

In [None]:
# Load the gene x COG-category annotation table of every species in
# species_list3. Unlike the loaders above there is no existence check, so a
# missing file raises here.
cog_category = []
for species in tqdm(species_list3):
    cog_category_inpath = (
        f"data/species/sp-{species}/midasdb_v20.emapper.gene75_x_cog_category.tsv"
    )
    cog_category.append(pd.read_table(cog_category_inpath))

# Rebind from a list of per-species frames to one concatenated frame.
cog_category = pd.concat(cog_category)

In [None]:
# Boolean matrix (rows: centroid_75 genes, columns: COG categories) that is
# True where the (gene, category) pair occurs in `cog_category` and False
# otherwise.
gene_x_cog_category_matrix = (
    cog_category.set_index(["centroid_75", "cog_category"])
    .assign(annotation=True)
    .unstack("cog_category", fill_value=False)
    .annotation
    # .reindex(spgc_gene_prevalence.index)
    # .fillna({'no_category': True})
    # .fillna(False)
)
# Number of genes annotated with each category.
gene_x_cog_category_matrix.sum()

In [None]:
# Double check that only genes with no other cog category get assigned "no_category":
# among rows flagged no_category, none may have more than one True column.
# NOTE(review): relies on the truthiness of whatever idxwhere() returns — confirm it is list-like.
assert not idxwhere(gene_x_cog_category_matrix[lambda x: x.no_category].sum(1) > 1)

In [None]:
# Prevalence class of every gene, from its prevalence among SPGC strains:
#   > 0.9 -> "core", > 0.15 -> "shell", > 0 -> "cloud", otherwise "absent".
# Computed in one vectorized pass so the values are plain strings. The previous
# element-wise `np.where` returned a 0-d array per gene and ran a Python call
# for each one. np.select takes the first matching condition, so the order of
# the thresholds matters. The index and the Series name ("prevalence") are
# kept because later cells refer to that column name.
spgc_prevalence_class = spgc_gene_prevalence.pipe(
    lambda prevalence: pd.Series(
        np.select(
            [
                prevalence.gt(0.9).to_numpy(),
                prevalence.gt(0.15).to_numpy(),
                prevalence.gt(0).to_numpy(),
            ],
            ["core", "shell", "cloud"],
            default="absent",
        ),
        index=prevalence.index,
        name=prevalence.name,
    )
)

In [None]:
# One row per gene: its prevalence class, its COG-category flags, and one
# boolean column per prevalence class. Genes classed "absent" are removed.
d0 = (
    spgc_prevalence_class.to_frame("prevalence_class")
    .join(gene_x_cog_category_matrix)
    .assign(
        cloud=lambda x: x.prevalence_class == "cloud",
        shell=lambda x: x.prevalence_class == "shell",
        core=lambda x: x.prevalence_class == "core",
        absent=lambda x: x.prevalence_class == "absent",
    )
)[lambda x: ~x.absent]

# Test every (prevalence class, COG category) pair for association.
# Note: "absent" is still iterated although those genes were removed above,
# so its rows always have num_genes == 0.
result = []
for _prevalence_class, _cog_category in tqdm(
    list(
        product(
            ["core", "shell", "cloud", "absent"], gene_x_cog_category_matrix.columns
        )
    )
):
    # 2x2 contingency table: rows = in prevalence class (True/False),
    # columns = has COG category (True/False); missing cells count as 0.
    d1 = (
        d0[[_prevalence_class, _cog_category]]
        .value_counts()
        .unstack()
        .reindex(index=[True, False], columns=[True, False])
        .fillna(0)
    )
    # Pseudocount of 1 per cell keeps the odds ratio finite.
    d1_pc = d1 + 1
    log_oddsratio = np.log2(
        (d1_pc.loc[True, True] / d1_pc.loc[True, False])
        / (d1_pc.loc[False, True] / d1_pc.loc[False, False])
    )
    result.append(
        (
            _prevalence_class,
            _cog_category,
            d1.loc[True, True],
            log_oddsratio,
            # Fisher's exact test on the raw (no pseudocount) table; p-value only.
            sp.stats.fisher_exact(d1)[1],
        )
    )
prevalence_class_cog_category_enrichment = pd.DataFrame(
    result,
    columns=[
        "prevalence_class",
        "cog_category",
        "num_genes",
        "log2_oddsratio",
        "pvalue",
    ],
).set_index(["prevalence_class", "cog_category"])

In [None]:
# Short display label for each COG category code (used as y tick labels in
# the enrichment figure).
cog_category_label = {
    "J": "Ribosomes / Translation - J",
    "A": "RNA Processing - A",
    "K": "Transcription - K",
    "L": "DNA replication/recombination/repair - L",
    "B": "Chromatin - B",
    "D": "Cell cycle control - D",
    "Y": "Nucleus - Y",
    "V": "Defense - V",
    "T": "Signal transduction - T",
    "M": "Cell envelope - M",
    "N": "Motility - N",
    "Z": "Cytoskeleton - Z",
    "W": "Extracellular structures - W",
    "U": "Secretion / vesicular transport - U",
    "O": "Protein processing - O",
    "X": '"Mobilome" - X',
    "C": "Energy - C",
    "G": "Carbohydrates - G",
    "E": "Amino acids - E",
    "F": "Nucleotides - F",
    "H": "Coenzymes - H",
    "I": "Lipids - I",
    "P": "Inorganic ions - P",
    "Q": "Secondary metabolites - Q",
    "R": "General only - R",
    "S": "TODO: This shouldn't show up",
    "no_category": "Unknown",
}

In [None]:
# Total number of genes per COG category, summed over prevalence classes,
# in ascending order.
prevalence_class_cog_category_enrichment.groupby(
    "cog_category"
).num_genes.sum().sort_values()

In [None]:
# Gene counts, summed over prevalence classes, for the categories Y, B and A.
prevalence_class_cog_category_enrichment.num_genes.unstack("prevalence_class").loc[
    ["Y", "B", "A"]
].sum(
    1
)  # .sum()

### Figure 4C

In [None]:
# Figure 4C: dot plot of COG-category enrichment per prevalence class.
# Dot color = log2 odds ratio, dot size = number of genes, "x" overlay = pairs
# with p-value >= 0.05.
d = prevalence_class_cog_category_enrichment

# Categories left out of the figure.
cog_category_drop_list = ["Y", "B", "A"]

# Rows are ordered by increasing log2 odds ratio in the core class.
cog_category_order = [
    c
    for c in d.xs("core").log2_oddsratio.sort_values(ascending=True).index
    if c not in cog_category_drop_list
]
cog_category_idx = pd.Series(
    np.arange(len(cog_category_order)), index=cog_category_order
).rename_axis("cog_category")
prevalence_class_order = ["core", "shell", "cloud"]
prevalence_class_idx = pd.Series(
    np.arange(len(prevalence_class_order)), index=prevalence_class_order
).rename_axis("prevalence_class")

num_genes_to_size = lambda x: 20 * np.log(x + 1)
signif_size = 20

# Attach plot coordinates. Pairs whose class or category is not in the orders
# above (e.g. "absent", dropped categories) get NaN coordinates and are
# therefore not drawn.
d = (
    d.join(prevalence_class_idx.rename("prevalence_class_idx"))
    .join(cog_category_idx.rename("cog_category_idx"))
    .assign(
        num_genes_s=lambda x: num_genes_to_size(x.num_genes),
        # Marker size is non-zero only where the p-value is >= 0.05.
        signif=lambda x: signif_size * (x.pvalue >= 0.05),
    )
)

# The colored quantity is a log2 odds ratio, so the color limits (and the
# colorbar ticks below) are expressed in log2 units too. With log10 units the
# "5:1" / "20:1" labels did not correspond to the odds ratios being plotted.
vmin, vmax = -np.log2(20), np.log2(20)

cmap = "coolwarm"
norm = mpl.colors.PowerNorm(1, vmin=vmin, vmax=vmax)

fig, ax = plt.subplots(figsize=(1, 8), facecolor="none")
ax.scatter(
    x="prevalence_class_idx",
    y="cog_category_idx",
    data=d,
    c="log2_oddsratio",
    s="num_genes_s",
    cmap=cmap,
    norm=norm,
    label="__nolegend__",
)
ax.scatter(
    x="prevalence_class_idx",
    y="cog_category_idx",
    data=d,
    s="signif",
    color="k",
    marker="x",
    label="__nolegend__",
    lw=1,
)
# for _, d1 in d.iterrows():
#     ax.annotate(d1.signif, xy=(d1.prevalence_class_idx, d1.cog_category_idx), ha='center', va='center')


ax.set_xlim(-0.5, len(prevalence_class_order) - 0.5)
ax.set_ylim(-1.0, len(cog_category_order))
ax.set_xticks(prevalence_class_idx)
ax.set_xticklabels(prevalence_class_order)
ax.set_yticks(cog_category_idx)
ax.set_yticklabels([cog_category_label[c] for c in cog_category_order])
lib.plot.rotate_yticklabels(ax=ax, rotation=30, va="top")
lib.plot.rotate_xticklabels(ax=ax)
# lib.plot.rotate_yticklabels(ax=ax, rotation=0, va='top')


# Remove frame
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.spines["left"].set_visible(False)

# Size legend: empty scatters for 1, 10, ..., 10,000 genes.
for num_genes in np.logspace(0, 4, num=5):
    ax.scatter(
        [],
        [],
        color="grey",
        edgecolor="black",
        label=num_genes,
        s=num_genes_to_size(num_genes),
    )
ax.legend(
    bbox_to_anchor=(1, 1),
    frameon=False,
    labelspacing=1,
)
fig.savefig("fig/fig4c_cogcat_enrichment.pdf", bbox_inches="tight")

# Stand-alone horizontal colorbar sharing the norm and colormap of the dots.
fig, ax = plt.subplots(figsize=(1.5, 0.1))
fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
    cax=ax,
    extend="both",
    label="Enrichment",
    orientation="horizontal",
)
# Tick positions in log2 units, matching the plotted log2 odds ratios.
ax.set_xticks([-np.log2(20), -np.log2(5), 0, np.log2(5), np.log2(20)])
ax.set_xticklabels(["≤1:20", "1:5", "1:1", "5:1", "≥20:1"])
lib.plot.rotate_xticklabels(ax=ax)

fig.savefig("fig/fig4c_cogcat_enrichment_cbar.pdf", bbox_inches="tight")

In [None]:
# Tabular view of the enrichment results for the core class, with category
# descriptions attached.
# Note: this cell rebinds `d`, `cog_category_order`, `cog_category_idx` and
# `num_genes_to_size` (scale 55 here), replacing the values set by the
# Figure 4C cell; no category is dropped from the order in this cell.
d = prevalence_class_cog_category_enrichment

cog_category_order = d.xs("core").log2_oddsratio.sort_values(ascending=True).index
cog_category_idx = pd.Series(
    np.arange(len(cog_category_order)), index=cog_category_order
).rename_axis("cog_category")
prevalence_class_order = ["core", "shell", "cloud"]
prevalence_class_idx = pd.Series(
    np.arange(len(prevalence_class_order)), index=prevalence_class_order
).rename_axis("prevalence_class")

num_genes_to_size = lambda x: 55 * np.log(x + 1)
signif_size = 20

d = (
    d.join(prevalence_class_idx.rename("prevalence_class_idx"))
    .join(cog_category_idx.rename("cog_category_idx"))
    .assign(
        num_genes_s=lambda x: num_genes_to_size(x.num_genes),
        # Non-zero only where the p-value is >= 0.05.
        signif=lambda x: signif_size * (x.pvalue >= 0.05),
    )
)

# Core-class rows, most enriched first, joined to the category descriptions.
d.xs("core").sort_values("log2_oddsratio", ascending=False).join(
    cog_category_description
)

In [None]:
# Clustered heatmap of 1 - (correlation-metric dissimilarity) between COG
# categories, computed over their gene membership columns.
# NOTE(review): assumes dmatrix() returns a square category x category table — confirm in lib.dissimilarity.
sns.clustermap(1 - dmatrix(gene_x_cog_category_matrix.T, metric="correlation"))

## AMR Genes Analysis

In [None]:
# Load the per-species AMR gene x strain table for every species, recording
# the species whose file is absent.
amr_gene = []
missing = []

for species in tqdm(species_list):
    amr_gene_inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.amr-strain_gene.tsv"
    if os.path.exists(amr_gene_inpath):
        amr_gene.append(
            # Prefix column labels with the species so they match the
            # "<species>_<genome_id>" strain labels used elsewhere.
            pd.read_table(amr_gene_inpath, index_col="gene_id").rename(
                columns=lambda x: species + "_" + x
            )
        )
        # Duplicate gene IDs would break the column-wise concat below.
        assert amr_gene[-1].index.is_unique
    else:
        missing.append(species)

# Rebind to one boolean gene x strain matrix: cells missing after the concat
# are filled with 0, then everything is thresholded at > 0.
amr_gene = pd.concat(amr_gene, axis=1).fillna(0) > 0
print(len(missing))

In [None]:
# Number of AMR genes per strain for the strains in spgc_strain_list2, with
# species taxonomy attached.
_strain_list = spgc_strain_list2

d0 = (
    # Strains absent from amr_gene become all-NaN columns after the reindex
    # and are removed by dropna(axis=1).
    amr_gene.reindex(columns=_strain_list)
    .dropna(axis=1)
    .sum()
    .rename_axis(index="species_strain")
    .reset_index(name="num_amr_accessions")
    .assign(
        # Split "<species>_<strain>" back into its two parts.
        species=lambda x: x.species_strain.str.split("_").str[0],
        strain=lambda x: x.species_strain.str.split("_").str[1],
    )
    .join(species_taxonomy, on="species")
)

# Fraction of strains with at least one AMR gene, then the count distribution.
print((d0.num_amr_accessions > 0).mean())
plt.hist(d0.num_amr_accessions, density=True)

In [None]:
# Number of strains with at least one AMR gene, and number of strains with a
# non-negative count.
print((d0.num_amr_accessions > 0).sum(), (d0.num_amr_accessions >= 0).sum())

In [None]:
# Per phylum: number of strains with AMR genes, number of strains, and the
# fraction with AMR genes, sorted by that fraction (highest first).
d0.assign(has_amr=lambda x: x.num_amr_accessions > 0).groupby("p__").has_amr.agg(
    ["sum", "count", "mean"]
).sort_values("mean", ascending=False)

In [None]:
# Per species: number and fraction of strains with at least one AMR gene and
# the total number of strains, with species taxonomy attached.
d1 = (
    d0.assign(has_amr_accessions=lambda x: x.num_amr_accessions.gt(0))
    .groupby("species")
    .has_amr_accessions.agg(["sum", "mean", "count"])
    .rename(
        columns={
            "sum": "num_strains_with_amr",
            "mean": "frac_strains_with_amr",
            "count": "total_num_strains",
        }
    )
    # .to_frame("frac_strains_with_amr")
    .join(species_taxonomy)
)
# Note: this bare `d1` is not the last expression of the cell, so it is not displayed.
d1
plt.hist(d1.frac_strains_with_amr, bins=20)

In [None]:
# Per-species fraction of strains with AMR genes, grouped by phylum.
sns.stripplot(data=d1, x="p__", y="frac_strains_with_amr")
lib.plot.rotate_xticklabels()

In [None]:
# Collect the centroid_75 IDs listed in the gene x AMR annotation table of
# every species in species_list3. There is no existence check, so a missing
# file raises here.
is_amr_gene = []
for species in tqdm(species_list3):
    amr_annot_inpath = f"data/species/sp-{species}/midasdb_v20.gene75_x_amr.tsv"
    is_amr_gene.append(pd.read_table(amr_annot_inpath))

# Unique gene IDs across all species.
amr_gene_list = list(
    pd.concat(is_amr_gene).centroid_75.unique()
)  # .drop_duplicates(subset='centroid_75').assign(is_amr_gene=True).set_index('centroid_75').is_amr_gene

In [None]:
# Test whether AMR genes are over- or under-represented in each prevalence class.
# One row per non-absent gene; the class column keeps the Series name
# ("prevalence"), which is what `x.prevalence` refers to below.
d0 = (
    spgc_prevalence_class[lambda x: x != "absent"]
    .to_frame()
    .assign(is_amr_gene=lambda x: x.index.to_series().isin(amr_gene_list))
)

# Note: "absent" is still iterated although those genes were removed above,
# so its table has no genes in the class.
for prevalence_class in ["core", "shell", "cloud", "absent"]:
    # 2x2 contingency table: rows = is AMR gene (False/True),
    # columns = is in this prevalence class (False/True); missing cells are 0.
    d1 = (
        d0.assign(is_prevalence_class=lambda x: x.prevalence == prevalence_class)[
            ["is_amr_gene", "is_prevalence_class"]
        ]
        .value_counts()
        .unstack("is_prevalence_class")
        .reindex(index=[False, True], columns=[False, True])
        .fillna(0)
    )
    # Pseudocount of 1 per cell keeps the odds ratio finite.
    d1_pc = d1 + 1
    log_odds_ratio = np.log2(
        (d1_pc.loc[True, True] / d1_pc.loc[True, False])
        / (d1_pc.loc[False, True] / d1_pc.loc[False, False])
    )
    # Fisher's exact test is run on the raw (no pseudocount) table.
    print(prevalence_class, log_odds_ratio, sp.stats.fisher_exact(d1))
    print(d1)

In [None]:
# Number of non-absent genes flagged as AMR genes.
d0.is_amr_gene.sum()

In [None]:
# Display the per-gene table of prevalence class and AMR flag.
d0

In [None]:
# Sanity check: expected to be empty, since "absent" genes were filtered out
# when d0 was built.
d0[lambda x: (x.prevalence == "absent")]