## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Switch the working directory to the location named by $PROJECT_ROOT, then
# display the resolved path as the cell output.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
from scipy.spatial.distance import pdist, squareform
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
import mpltern

from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
# Plot styling: seaborn "talk" context and a fixed figure resolution.
sns.set_context("talk")
plt.rcParams.update({"figure.dpi": 100})

In [None]:
# Colour assigned to each genome type.
genome_type_palette = {
    "SPGC": "tab:green",
    "MAG": "tab:orange",
    "Isolate": "tab:blue",
    "Ref": "black",
}

## Data Setup

### Metadata

In [None]:
# Unique species ids (as strings) whose species_group_id is "hmp2".
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda df: df["species_group_id"] == "hmp2", "species_id"]
    .astype(str)
    .unique()
)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a Series keyed by rank prefix.

    The string must contain exactly seven fields, which are labelled (in order)
    "d__", "p__", "c__", "o__", "f__", "g__", "s__". Field values are returned
    unchanged (nothing is stripped from them).
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    return pd.Series(taxonomy_string.split(";"), index=rank_prefixes)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Keep only rows where the genome is its own species representative, build the
# species id as "1" + the third "-"-separated field of the MGnify accession,
# and expand each Lineage string into one column per taxonomic rank.
species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)
    .loc[lambda df: df["Genome"] == df["Species_rep"]]
    .assign(
        species_id=lambda df: "1" + df["MGnify_accession"].str.split("-").str.get(2)
    )
    .set_index("species_id")["Lineage"]
    .apply(parse_taxonomy_string)
)
species_taxonomy

In [None]:
# First-pass palette: every phylum present in the taxonomy, alphabetically.
phylum_order0 = sorted(species_taxonomy["p__"].unique())

phylum_palette0 = lib.plot.construct_ordered_palette(
    phylum_order0,
    cm="rainbow",
    desaturate_levels=[1.0, 0.5],
)

# Empty scatter calls draw no points; they only register one legend entry per
# phylum so the palette can be inspected visually.
for p__ in phylum_order0:
    _swatch_color = phylum_palette0[p__]
    plt.scatter([], [], color=_swatch_color, label=p__)
plt.legend(ncols=4)
lib.plot.hide_axes_and_spines()

# assert len(set(phylum_palette.values())) == len((phylum_palette.values()))

### Strain Statistics

In [None]:
def classify_genome(x):
    """Assign a QC class label to a single genome record.

    Parameters
    ----------
    x
        A row-like object (e.g. one row of the strain metadata table) exposing
        the attributes ``genome_type`` and ``passes_filter`` and, for SPGC
        genomes that fail the filter, ``passes_geno_positions`` and
        ``passes_in_sample_list``. Flags are interpreted by truthiness, so
        ``bool``, ``numpy.bool_`` and 0/1 numeric values are all accepted.
        Missing values are not handled specially.

    Returns
    -------
    str
        One of "isolate", "isolate_fails_qc", "mag", "mag_fails_qc", "spgc",
        "sfacts_only" or "sfacts_fails_qc".

    Raises
    ------
    ValueError
        If ``genome_type`` is not "Isolate", "MAG" or "SPGC".
    """
    # Use boolean logic (`not` / `and`) rather than bitwise `~` / `&`: on a
    # plain Python bool `~True == -2`, which only happened to give the right
    # answer after `&`, and is deprecated as of Python 3.12.
    genome_type = x.genome_type
    passes_filter = bool(x.passes_filter)

    if genome_type == "Isolate":
        return "isolate" if passes_filter else "isolate_fails_qc"
    if genome_type == "MAG":
        return "mag" if passes_filter else "mag_fails_qc"
    if genome_type == "SPGC":
        if passes_filter:
            return "spgc"
        # Failed the full filter: the two remaining flags decide between
        # "sfacts_only" and "sfacts_fails_qc".
        if x.passes_geno_positions and x.passes_in_sample_list:
            return "sfacts_only"
        return "sfacts_fails_qc"
    raise ValueError("Genome did not match classification criteria:", x)

In [None]:
_species_list = species_list
# _species_list = ["100003"]

# Read the per-species strain metadata table for every species that has one,
# remembering which species have no file on disk.
missing_species = []
_strain_meta_tables = []
for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv"
    if os.path.exists(inpath):
        data = pd.read_table(inpath).assign(species=species, inpath=inpath)
        _strain_meta_tables.append(data)
    else:
        missing_species.append(species)

# Stack all species and label every genome with its QC class.
filt_stats = pd.concat(_strain_meta_tables)
filt_stats = filt_stats.assign(
    genome_class=filt_stats.apply(classify_genome, axis=1)
)

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
# Define different subsets of the species.
# (All species: species_list)

# Row subsets shared by several of the lists below.
_is_spgc = filt_stats.genome_type.isin(["SPGC"])
_sfacts_strains = filt_stats[filt_stats.passes_geno_positions & _is_spgc]
_spgc_strains = filt_stats[filt_stats.passes_filter & _is_spgc]
_sfacts_strain_counts = _sfacts_strains.species.value_counts()
_spgc_strain_counts = _spgc_strains.species.value_counts()

# All species with enough positions
species_list0 = filt_stats[filt_stats.passes_geno_positions].species.unique()

# All species with sf strains
species_list1 = _sfacts_strains.species.unique()

# All species with sf strains to talk about distributions (>=10)
species_list1b = idxwhere(_sfacts_strain_counts >= 10)

# All species with spgc strains
species_list2 = _spgc_strains.species.unique()

# All species with enough spgc strains for pangenome analysis (>=10)
species_list3 = idxwhere(_spgc_strain_counts >= 10)

# Species with large numbers of strains (>=20)
species_list4 = idxwhere(_spgc_strain_counts >= 20)

# Report the size and phylum composition of each subset.
_species_list_map = {
    "All considered species": species_list,
    "Species with enough genotyped positions": species_list0,
    "With sfacts strains": species_list1,
    "With (>=10) sfacts strains": species_list1b,
    "With SPGC inferences": species_list2,
    "With >=10 inferences": species_list3,
    "With >=20 inferences": species_list4,
}
for _species_list_name, _species_list in _species_list_map.items():
    _phylum_counts = species_taxonomy.loc[_species_list].p__.value_counts()
    print(_species_list_name, len(_species_list))
    print(_phylum_counts)
    print()

## Better phylum palette

In [None]:
# Number of species per phylum among the species in species_list1.
species_taxonomy.loc[species_list1, "p__"].value_counts()

In [None]:
# Hand-curated phylum ordering; the palette assigns colours following this order.
phylum_order = [
    "p__Euryarchaeota",
    "p__Thermoplasmatota",
    "p__Firmicutes",
    "p__Firmicutes_A",
    "p__Firmicutes_C",
    # "p__Firmicutes_B", # None in species_list1
    # "p__Firmicutes_G", # B/G/I not sure how related to C or A
    # "p__Firmicutes_I", #
    # "p__Cyanobacteria", # None in species_list1
    "p__Actinobacteriota",
    "p__Synergistota",
    "p__Fusobacteriota",
    "p__Campylobacterota",
    "p__Proteobacteria",
    "p__Desulfobacterota_A",
    "p__Bacteroidota",
    "p__Verrucomicrobiota",
    # "dummy0", # 18
    # "dummy1", # 19
    # "dummy2", # 20
]

phylum_palette = lib.plot.construct_ordered_palette(
    phylum_order, cm="rainbow", desaturate_levels=[1.0, 0.5]
)

# Print each colour and register an (empty) legend entry for it so the palette
# can be checked by eye.
for p__ in phylum_order:
    _swatch_color = phylum_palette[p__]
    print(p__, _swatch_color)
    plt.scatter([], [], color=_swatch_color, label=p__)
plt.legend(ncols=4)
lib.plot.hide_axes_and_spines()

# assert len(set(phylum_palette.values())) == len((phylum_palette.values()))

## Prevalences

In [None]:
_species_list = species_list3

# Collect the per-gene prevalence table for each species that has one on disk;
# species without a file are recorded in `missing_species`.
missing_species = []
_spgc_prevalence_parts = []
for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.prevalence.tsv"
    if os.path.exists(inpath):
        data = pd.read_table(
            inpath, names=['gene_id', 'prevalence'], index_col='gene_id'
        )['prevalence']
        _spgc_prevalence_parts.append(data)
    else:
        missing_species.append(species)

# One Series over all species, indexed (and sorted) by gene id.
spgc_gene_prevalence = pd.concat(_spgc_prevalence_parts).sort_index()

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
_species_list = species_list3

# Same as the SPGC prevalence loader, but reading the per-species
# midasdb.gene75_v15 prevalence tables.
missing_species = []
_ref_prevalence_parts = []
for species in tqdm(_species_list):
    inpath = f"data/species/sp-{species}/midasdb.gene75_v15.uhgg-strain_gene.prevalence.tsv"
    if os.path.exists(inpath):
        data = pd.read_table(
            inpath, names=['gene_id', 'prevalence'], index_col='gene_id'
        )['prevalence']
        _ref_prevalence_parts.append(data)
    else:
        missing_species.append(species)

# One Series over all species, indexed (and sorted) by gene id.
ref_gene_prevalence = pd.concat(_ref_prevalence_parts).sort_index()

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
exclude_genes_never_greater_than = 0.01

# Align the two prevalence Series on gene id; a gene missing from one source is
# given prevalence 0 there. Genes at or below the threshold in both are dropped.
_prevalence_pairs = pd.DataFrame(
    dict(ref=ref_gene_prevalence, spgc=spgc_gene_prevalence)
).fillna(0)
d = _prevalence_pairs[
    _prevalence_pairs.max(axis=1) > exclude_genes_never_greater_than
]

bins = np.linspace(0, 1, num=51)
fig, ax = plt.subplots(figsize=(6.5, 5))
*_, art = ax.hist2d(
    'ref',
    'spgc',
    data=d,
    bins=bins,
    norm=mpl.colors.SymLogNorm(1, vmin=1),
)
fig.colorbar(art, ax=ax, label='Genes (count)')

ax.set_aspect(1)
ax.set_xlabel('Reference Prevalence')
ax.set_ylabel('SPGC Prevalence')

# Label the prevalence classes and draw the two class boundaries.
for _class_label, _label_y in [
    ('core (≥90%)', 0.901),
    ('shell (15-90%)', 0.151),
    ('cloud (<15%)', 0.0),
]:
    ax.annotate(_class_label, xy=(0.5, _label_y), ha='center', va='bottom', color='white')
for _class_boundary in (0.9, 0.15):
    ax.axhline(_class_boundary, lw=1, linestyle='--', color='white')

sp.stats.pearsonr(d.ref, d.spgc)

## Genome Fractions

In [None]:
# TODO: Gather genome fractions for strains
# Filter strains

_species_list = species_list2

missing_species = []
_class_count_tables = []
for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.prevalence_class_fraction.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue

    # SPGC genomes of this species that pass the filter.
    strain_list = filt_stats[
        lambda x: (x.species == species) & (x.passes_filter) & (x.genome_type == 'SPGC')
    ].genome_id.unique()

    # Keep only those strains and index the table by (species, strain).
    data = (
        pd.read_table(inpath, index_col='strain')
        .rename(index=str)
        .loc[strain_list]
        .assign(species=species)
        .reset_index()
        .set_index(['species', 'strain'])
    )
    _class_count_tables.append(data)

spgc_prevalence_class_counts = pd.concat(_class_count_tables).sort_index()

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
# Per-species median of each column, with every row rescaled to sum to 1.
(
    spgc_prevalence_class_counts
    .groupby('species')
    .median()
    .apply(lambda row: row / row.sum(), axis=1)
)

In [None]:
# Per-species median of each column, with every row rescaled to sum to 1.
median_prevalence_class_fraction = (
    spgc_prevalence_class_counts
    .groupby('species')
    .median()
    .apply(lambda row: row / row.sum(), axis=1)
)

In [None]:
# Attach each species' phylum (column "c") and display the table sorted by it.
(
    median_prevalence_class_fraction
    .assign(c=lambda x: x.index.to_series().map(species_taxonomy.p__))
    .sort_values('c')
)

In [None]:
# Ternary scatter of the per-species median core / shell / cloud fractions,
# one point per species in species_list3, coloured by phylum.
# NOTE(review): the "ternary" projection is presumably the one registered by the
# `mpltern` import at the top of the notebook — confirm.
fig = plt.figure(figsize=(8, 8), facecolor='none')
ax = fig.add_subplot(projection="ternary", ternary_sum=100.0, rotation=180)
ax.grid()


ax.set_tlabel("Core (%)")
ax.set_llabel("Shell (%)")
ax.set_rlabel("Cloud (%)")

# ax.set_tlim(20, 100)
# ax.set_llim(10, 62)
# ax.set_rlim(0, 35)

# Plot data: median class fractions plus each species' phylum and its number of
# rows (strains) in spgc_prevalence_class_counts.
d0 = median_prevalence_class_fraction.loc[species_list3].assign(
    p__=lambda x: x.index.to_series().map(species_taxonomy.p__),
    num_genomes=spgc_prevalence_class_counts.reset_index()['species'].value_counts(),
)

# A second, separate figure that only hosts the phylum legend.
_, ax_legend = plt.subplots()

for p__ in phylum_order:
    # Species belonging to this phylum.
    d1 = d0[lambda x: x.p__ == p__]
    # Open markers on the ternary axes; the label value '__nolegend__' is used
    # because the legend is built from the proxy artists on ax_legend instead.
    ax.scatter(
                "core",
                "shell",
                "cloud",
                data=d1,
                color=phylum_palette[p__],
                facecolor='none',
                s=15,
                label='__nolegend__',
                # marker="o",
                # lw=2,
                # facecolor="none",
                # alpha=0.85,
            )
    # Empty scatter: draws nothing, but registers a legend entry for the phylum.
    ax_legend.scatter([], [], color=phylum_palette[p__],
                facecolor='none', label=p__, s=50, lw=3)
ax_legend.legend(bbox_to_anchor=(1, 0.5))
lib.plot.hide_axes_and_spines(ax_legend)

In [None]:
# Phyla that actually occur among species_list3, kept in phylum_order order.
_reduced_phylum_list = species_taxonomy.loc[species_list3].p__.unique()
_reduced_phylum_order = [p__ for p__ in phylum_order if p__ in _reduced_phylum_list]
p__meta = pd.DataFrame([], index=_reduced_phylum_order).assign(
    pos=lambda x: np.arange(len(x))
)


# A thin figure whose x tick labels are the phylum names, each drawn in its
# palette colour.
fig, ax = plt.subplots(figsize=(5, 1), facecolor='none')
ax.set_xticks(p__meta.pos.unique())
ax.set_xticklabels(
    p__meta.index.to_series().str.replace('p__', ''),
    fontdict=dict(weight='heavy'),
)
for xtick, p__ in zip(ax.get_xticklabels(), p__meta.index):
    c = phylum_palette[p__]
    print(xtick, p__, c)
    xtick.set_color(c)

ax.set_xlabel('Phylum')

# # sns.stripplot(x='p__', y='Ref_branch_frac', data=d)

lib.plot.rotate_xticklabels(ax=ax, rotation=35)

## Enrichment

In [None]:
# Description text for each COG category, indexed by category code. The file is
# read without a header row; its columns are named here.
_cog_category_table = pd.read_table(
    "ref/cog-20.categories.tsv",
    names=["cog_category", "color", "description"],
    index_col="cog_category",
)
cog_category_description = _cog_category_table["description"]
cog_category_description.sort_index()

In [None]:
# Read the gene x COG-category table of every species in species_list3 and
# stack them. There is no missing-file guard here: an absent file raises.
_cog_category_tables = []
for species in tqdm(species_list3):
    cog_category_inpath = f'data/species/sp-{species}/midasdb_v15.emapper.gene75_x_cog_category.tsv'
    _cog_category_tables.append(pd.read_table(cog_category_inpath))

cog_category = pd.concat(_cog_category_tables)

In [None]:
# Pivot the long (centroid_75, cog_category) table into a boolean matrix:
# rows are genes, columns are COG categories, True where the pair is present.
gene_x_cog_category_matrix = (
    cog_category
    .set_index(['centroid_75', 'cog_category'])
    .assign(annotation=True)
    .unstack('cog_category', fill_value=False)
    ['annotation']
)
# Number of genes annotated with each category.
gene_x_cog_category_matrix.sum()

In [None]:
# Bin every gene's SPGC prevalence into a class label. np.select takes the first
# matching condition, so the order core -> shell -> cloud matters; anything that
# matches none of them (prevalence 0, or NaN) is "absent".
# This is vectorised so the result holds plain strings; calling np.where per
# element inside Series.map stored 0-d arrays and looped in Python.
# NOTE(review): the comparisons are strict (> 0.9, > 0.15), while the histogram
# annotations read "core (≥90%)" / "shell (15-90%)" — confirm which is intended.
_prevalence_values = spgc_gene_prevalence.to_numpy()
spgc_prevalence_class = pd.Series(
    np.select(
        [
            _prevalence_values > 0.9,
            _prevalence_values > 0.15,
            _prevalence_values > 0,
        ],
        ["core", "shell", "cloud"],
        default="absent",
    ),
    index=spgc_gene_prevalence.index,
    name=spgc_gene_prevalence.name,
)

In [None]:
# Gene-level table: prevalence class, COG-category annotations, and one boolean
# indicator column per prevalence class. Genes classed as "absent" are dropped.
d0 = spgc_prevalence_class.to_frame("prevalence_class").join(gene_x_cog_category_matrix)
for _class_name in ["cloud", "shell", "core", "absent"]:
    d0[_class_name] = d0["prevalence_class"] == _class_name
d0 = d0[~d0["absent"]]

# For every (prevalence class, COG category) pair, build the 2x2 contingency
# table (rows: in class?, columns: in category?) and test it.
_class_category_pairs = list(
    product(["core", "shell", "cloud", "absent"], gene_x_cog_category_matrix.columns)
)
result = []
for _prevalence_class, _cog_category in tqdm(_class_category_pairs):
    _pair_counts = d0[[_prevalence_class, _cog_category]].value_counts()
    d1 = (
        _pair_counts.unstack()
        .reindex(index=[True, False], columns=[True, False])
        .fillna(0)
    )
    # A pseudocount of 1 per cell keeps the odds ratio finite for empty cells.
    d1_pc = d1 + 1
    _odds_in_class = d1_pc.loc[True, True] / d1_pc.loc[True, False]
    _odds_outside_class = d1_pc.loc[False, True] / d1_pc.loc[False, False]
    log_oddsratio = np.log2(_odds_in_class / _odds_outside_class)
    # The p-value is Fisher's exact test on the raw (no pseudocount) table.
    _pvalue = sp.stats.fisher_exact(d1)[1]
    result.append(
        (_prevalence_class, _cog_category, d1.loc[True, True], log_oddsratio, _pvalue)
    )

prevalence_class_cog_category_enrichment = pd.DataFrame(
    result,
    columns=["prevalence_class", "cog_category", "num_genes", "log2_oddsratio", "pvalue"],
).set_index(["prevalence_class", "cog_category"])

In [None]:
d = prevalence_class_cog_category_enrichment

# Wide tables (COG category x prevalence class) of the log2 odds ratio and of
# the significance marker ("·" where p < 0.05, empty otherwise).
d_oddsr = d.log2_oddsratio.unstack("prevalence_class")
d_signf = d.pvalue.map(lambda x: np.where(x < 0.05, "·", "")).unstack(
    "prevalence_class"
)

prevalence_class_order = ["core", "shell", "cloud", "absent"]
# Rows ordered by descending enrichment in the core class.
cog_category_order = d_oddsr["core"].sort_values(ascending=False).index

fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    d_oddsr.loc[cog_category_order, prevalence_class_order],
    norm=mpl.colors.PowerNorm(1, vmin=-2, vmax=+2),
    cmap="coolwarm",
    ax=ax,
    cbar_kws=dict(extend='both'),
)
# Annotations (because seaborn annotations are failing): write the marker at the
# centre of each heatmap cell.
for i, _cog_category in enumerate(cog_category_order):
    for j, _prevalence_class in enumerate(prevalence_class_order):
        ax.annotate(
            d_signf.loc[_cog_category, _prevalence_class],
            xy=(j + 0.5, i + 0.5),
            ha="center",
            va="center",
        )

In [None]:
# Short display label for each COG category code.
# NOTE(review): the grouping comments follow the usual COG functional groups —
# confirm against ref/cog-20.categories.tsv.
cog_category_label = {
    # Information storage and processing
    "J": "Ribosomes / Translation - J",
    "A": "RNA Processing - A",
    "K": "Transcription - K",
    "L": "DNA replication/recombination/repair - L",
    "B": "Chromatin - B",
    # Cellular processes and signaling
    "D": "Cell cycle control - D",
    "Y": "Nucleus - Y",
    "V": "Defense - V",
    "T": "Signal transduction - T",
    "M": "Cell envelope - M",
    "N": "Motility - N",
    "Z": "Cytoskeleton - Z",
    "W": "Extracellular structures - W",
    "U": "Secretion / vesicular transport - U",
    "O": "Protein processing - O",
    "X": '"Mobilome" - X',
    # Metabolism
    "C": "Energy - C",
    "G": "Carbohydrates - G",
    "E": "Amino acids - E",
    "F": "Nucleotides - F",
    "H": "Coenzymes - H",
    "I": "Lipids - I",
    "P": "Inorganic ions - P",
    "Q": "Secondary metabolites - Q",
    # Poorly characterized
    "R": "General only - R",
    "S": "TODO: This shouldn't show up",
    # Label for the 'no_category' key
    "no_category": "Unknown",
}

In [None]:
# Sum of num_genes over prevalence classes for each COG category, ascending.
(
    prevalence_class_cog_category_enrichment
    .groupby(level='cog_category')['num_genes']
    .sum()
    .sort_values()
)

In [None]:
# Bubble plot of COG-category enrichment per prevalence class: marker colour is
# the log2 odds ratio, marker size grows with the number of genes.
d = prevalence_class_cog_category_enrichment

# Rows ordered by ascending core log2 odds ratio; the *_idx Series map each
# label to its integer plotting position.
cog_category_order = d.xs('core').log2_oddsratio.sort_values(ascending=True).index
cog_category_idx = pd.Series(np.arange(len(cog_category_order)), index=cog_category_order).rename_axis('cog_category')
prevalence_class_order = ['core', 'shell', 'cloud']
prevalence_class_idx = pd.Series(np.arange(len(prevalence_class_order)), index=prevalence_class_order).rename_axis('prevalence_class')

# Marker size as a function of gene count (log-scaled), and the size of the
# marker overlaid by the second scatter call.
num_genes_to_size = lambda x: 55 * np.log(x + 1)
signif_size = 20

# Attach plotting coordinates and sizes. 'absent' is not in
# prevalence_class_order, so those rows get no x position from the join.
# `signif` is signif_size where pvalue >= 0.05 and 0 otherwise, i.e. the 'x'
# marker below is drawn on the NON-significant cells.
# NOTE(review): the heatmap cell above marks p < 0.05 instead — confirm that
# crossing out non-significant cells is the intent here.
d = d.join(prevalence_class_idx.rename('prevalence_class_idx')).join(cog_category_idx.rename('cog_category_idx')).assign(num_genes_s=lambda x: num_genes_to_size(x.num_genes), signif=lambda x: signif_size * (x.pvalue >= 0.05))

# Colour scale limits for the log2 odds ratio.
vmin, vmax = -2, 2

fig, ax = plt.subplots(figsize=(2, 12))
# Bubbles: colour = log2 odds ratio, size = scaled gene count.
ax.scatter(x='prevalence_class_idx', y='cog_category_idx', data=d, c='log2_oddsratio', s='num_genes_s', cmap='coolwarm', norm=mpl.colors.PowerNorm(1, vmin=vmin, vmax=vmax), label='__nolegend__')
# Overlay: black 'x' sized by the `signif` column.
ax.scatter(x='prevalence_class_idx', y='cog_category_idx', data=d, s='signif', color='k', marker='x', label='__nolegend__', lw=1)
# for _, d1 in d.iterrows():
#     ax.annotate(d1.signif, xy=(d1.prevalence_class_idx, d1.cog_category_idx), ha='center', va='center')



# Axis limits and tick labels: one column per prevalence class, one row per
# COG category, labelled with the short names from cog_category_label.
ax.set_xlim(-0.5, len(prevalence_class_order) - 0.5)
ax.set_ylim(-1.0, len(cog_category_order))
ax.set_xticks(prevalence_class_idx)
ax.set_xticklabels(prevalence_class_order)
ax.set_yticks(cog_category_idx)
ax.set_yticklabels([cog_category_label[c] for c in cog_category_order])
lib.plot.rotate_xticklabels(ax=ax)

# Remove frame
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)

# Legend proxies: empty scatters that add one entry per reference colour
# (log2 odds ratio from -2 to 2, the same range as vmin/vmax) and one per
# reference size (gene counts 1 to 10^4).
for log2_oddsratio in np.linspace(-2, 2, num=5):
    ax.scatter([], [], color=mpl.cm.coolwarm((log2_oddsratio - vmin) / (vmax - vmin)), label=log2_oddsratio)
for num_genes in np.logspace(0, 4, num=5):
    ax.scatter([], [], color='black', label=num_genes, s=num_genes_to_size(num_genes))
ax.legend(bbox_to_anchor=(1, 1))
# lib.plot.hide_axes_and_spines(ax)

In [None]:
d
prevalence_class_x = ["core", "shell", "cloud"]
cog_category_order = d_oddsr["core"].sort_values(ascending=False).index

# Sort all matrices:
