## Preamble

In [None]:
!date

### Template Utils

In [None]:
%load_ext autoreload
%load_ext line_profiler

In [None]:
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp

import fastcluster
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def morans_i(x, w, return_zscore=False, centered=False):
    """Compute Moran's I for values `x` under a pairwise weight matrix `w`.

    Parameters
    ----------
    x : 1-D array or Series of length n
        Observed values, one per item.
    w : (n, n) array
        Pairwise weights between items, e.g.

            w = np.exp(-(strain_geno_pdist.loc[_genome_list, _genome_list].values))
            w[np.diag_indices_from(w)] = 0

    return_zscore : bool, default False
        If True, also return the z-score of the observed statistic relative to
        its expectation and variance under the null hypothesis.
    centered : bool, default False
        If True, report the statistic minus its null expectation, -1 / (n - 1).

    Returns
    -------
    float, or (float, float) when `return_zscore` is True
        The (optionally centered) statistic, followed by the z-score if
        requested. When `x` has zero variance the statistic is undefined and
        NaN is returned in every slot, so the number of returned values always
        matches `return_zscore`.
    """
    x_centered = x - x.mean()
    denominator = (x_centered**2).sum()
    if denominator == 0:
        # Zero variance: nothing to correlate. Keep the return arity consistent
        # so that `i, z = morans_i(..., return_zscore=True)` can always unpack.
        return (np.nan, np.nan) if return_zscore else np.nan

    n = len(x)
    w_sum = w.sum()

    # Weighted sum of cross-products over all (i, j) pairs.
    numerator = np.einsum("i,j,ij->", x_centered, x_centered, w)
    normalizer = n / w_sum
    observed = numerator * normalizer / denominator

    # Expected value of Moran's I under the null hypothesis.
    expected = -1 / (n - 1)
    observed_out = observed - expected if centered else observed

    if not return_zscore:
        return observed_out

    # Variance of I under the null; `s3` is the kurtosis-like ratio of the
    # fourth central moment to the squared second central moment of `x`.
    s1 = 0.5 * ((w + w.T) ** 2).sum()
    s2 = ((w.sum(0) + w.sum(1)) ** 2).sum()
    fourth_moment = (x_centered**4).sum() / n
    second_moment = denominator / n
    s3 = fourth_moment / second_moment**2
    s4 = (n**2 - 3 * n + 3) * s1 - n * s2 + 3 * w_sum**2
    s5 = (n**2 - n) * s1 - 2 * n * s2 + 6 * w_sum**2

    i_var = (n * s4 - s3 * s5) / (
        (n - 1) * (n - 2) * (n - 3) * w_sum**2
    ) - expected**2

    # The z-score always uses the uncentered statistic, regardless of `centered`.
    # An upper-tail p-value could be derived from it as:
    #     1 - sp.stats.norm(loc=0, scale=1).cdf(zscore)
    zscore = (observed - expected) / np.sqrt(i_var)
    return observed_out, zscore

In [None]:
def geno_pdist_to_weights(pdist, rate=2):
    """Turn a square distance matrix into exponentially decaying weights.

    Each distance d becomes exp(-rate * d), and the diagonal is then set to
    zero. The input is copied into a new array first, so the caller's matrix
    is left untouched.
    """
    distances = np.array(pdist)
    weights = np.exp(distances * -rate)
    diagonal = np.diag_indices_from(weights)
    weights[diagonal] = 0
    return weights

In [None]:
from numba import jit


# @jit(nopython=True, fastmath=False, nogil=True)
def genotype_dissimilarity_transformed_values_jit(x, y):
    """Weighted mean of |x - y| / 2, weighted element-wise by |x * y|.

    Falls back to the unweighted mean of |x - y| / 2 when the weighted mean
    evaluates to NaN (e.g. when every weight is zero).
    """
    half_gap = np.abs((x - y) / 2)
    confidence = np.abs(x * y)
    weighted_mean = (confidence * half_gap).sum() / confidence.sum()
    # NOTE: The result is deliberately not raised to the power (1 / q):
    # doing so would lose the city-block distance interpretation when x and y
    # are both discrete (i.e. one of {0, 1}).

    # The weighted mean is undefined where the weights sum to zero (which only
    # happens when one of x or y is exactly 0.5 at every index), but the limit
    # approaches the same value from both directions. The dissimilarity is
    # therefore defined piecewise, while remaining smooth and defined
    # everywhere.
    return half_gap.mean() if np.isnan(weighted_mean) else weighted_mean

In [None]:
def linkage_order(linkage, labels):
    """Return `labels` reordered to follow the leaf order of a linkage tree.

    The linkage matrix is converted to a tree and its leaves are visited in
    pre-order; `labels[i]` must be the label of leaf `i`.
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return list(np.array(labels)[leaf_ids])

### Set Style

In [None]:
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 50

### Papermill parameters

In [None]:
# This cell is tagged "parameters" for papermill.
# See <https://papermill.readthedocs.io/en/latest/usage-parameterize.html#how-parameters-work> for some gotchas.
# NOTE: *ALL* parameters should be passed to papermill. Values set here are only for prototyping.
species_id = "100003"
show_unimportant_figures = "True"  # Since this is coming in through papermill, I'll need to treat it as a boolean below, and the only False string is "".

group = "xjin_ucfmt_hmp2"
spgc_ss_stem = 'all'
centroidA = '99'
centroidB = '75'
spgc_specgene_stem = "spgc_specgene-ref2-p95"
pgene_stem = f"gene{centroidA}_new-v22-agg{centroidB}"
spgc_stem = f"{spgc_specgene_stem}_ss-{spgc_ss_stem}_t-30_thresh-corr150-depth250"
sfacts_stem = "filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0"

species_taxonomy_inpath = "ref/gtpro/species_taxonomy_ext.tsv"
sample_to_spgc_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.spgc_ss-{spgc_ss_stem}.strain_samples.tsv"
sfacts_fit_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.world.nc"
spgc_agg_mgtp_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.spgc_ss-{spgc_ss_stem}.strain_mgtp.nc"
ref_geno_inpath = (
    f"data/species/sp-{species_id}/midasdb.geno.nc"  # "Re-calculated GT-Pro genotype"
)
# ref_geno_inpath = (
#     f"data/species/sp-{species_id}/gtpro_ref.mgtp.nc"  # "Reference GT-Pro genotype"
# )
spgc_meta_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.strain_meta.tsv"
ref_gene_copy_number_uhgg_inpath = (
    f"data/species/sp-{species_id}/gene{centroidB}_new.reference_copy_number.nc"
)
spgc_gene_uhgg_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.strain_gene.tsv"
spgc_gene_uhgg_depth_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_specgene_stem}_ss-{spgc_ss_stem}_t-30.strain_depth_ratio.tsv"
spgc_gene_uhgg_corr_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_specgene_stem}_ss-{spgc_ss_stem}_t-30.strain_correlation.tsv"
uhgg_x_eggnog_inpath = (
    f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.gene_x_eggnog.tsv"
)
uhgg_x_top_eggnog_inpath = f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.gene_x_top_eggnog.tsv"
uhgg_gene_length_inpath = (
    f"ref/midasdb_uhgg_new/pangenomes/{species_id}/cluster_info.txt"
)

# TODO: Add these to the snakemake rule.
# FIXME(review): `dvm` is interpolated below but is not assigned anywhere in the
# visible part of this notebook, so this line raises NameError on a fresh
# kernel. Papermill injects its parameters in a cell *after* this one, so
# passing `dvm` on the command line does not help. Define it above (with the
# other stems) before use.
geno_pdist_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.spgc_ss-{spgc_ss_stem}.geno_uhgg-{dvm}_pdist-mask10-pseudo10.pkl"
gene_pdist_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.gene_filt5_jaccard_pdist.pkl"

gene_annotations_inpath = f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.d/proteins.emapper.annotations"
uhgg_depth_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.{pgene_stem}.depth2.nc"
species_depth_inpath = f"data/group/{group}/species/sp-{species_id}/r.proc.{pgene_stem}.{spgc_specgene_stem}.species_depth.tsv"

mgen_inpath = "meta/hmp2/mgen.tsv"
preparation_inpath = "meta/hmp2/preparation.tsv"
stool_inpath = "meta/hmp2/stool.tsv"
subject_inpath = "meta/hmp2/subject.tsv"

# TODO: Export SPGC gene info: prevalence, passes filtering against references,
#    concordance with SNP dissimilarity in each genome set (both using SPGC-only, SPGC+refs, refs-only),
#    cluster membership for each clustering analysis (SPGC-only, SPGC+refs, refs-only),
gene_stats_outpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.gene_stats.tsv"
# TODO: Export info for each SPGC strain: passes filtering? relationship between
#   minimum genotype distance to a reference and the associated minimum gene
#   distance (both weighted and unweighted)
spgc_strain_stats_outpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.spgc_strain_stats.tsv"
# TODO: Export info for each reference strain: relationship between
#   minimum genotype distance to a reference and the associated minimum gene
#   distance (both weighted and unweighted)
ref_strain_stats_outpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.ref_strain_stats.tsv"
# TODO: Export compiled metadata (length, COG category table), etc.
gene_meta_outpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.gene_meta.tsv"

# TODO: Export anything else I need for cross-species comparisons. What might that be?
# strain_mwas_outpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.strain_mwas.tsv"
html_outpath = f"data/group/{group}/species/sp-{species_id}/r.proc.gtpro.{sfacts_stem}.{pgene_stem}.{spgc_stem}.spgc_ref_comparison.html"

## Data Loading / Validation

### Taxonomy

In [None]:
species_taxonomy = lib.thisproject.data.load_species_taxonomy(species_taxonomy_inpath)
species_taxonomy.loc[species_id]

### Gene Annotations

In [None]:
uhgg_x_eggnog = pd.read_table(uhgg_x_eggnog_inpath)

In [None]:
uhgg_x_top_eggnog = pd.read_table(uhgg_x_top_eggnog_inpath)

In [None]:
# Mean of the `centroid_99_length` column within each `centroid_75` group of
# the cluster-info table; the resulting Series is indexed by `centroid_75`.
# NOTE(review): these column names are hard-coded rather than derived from the
# `centroidA` / `centroidB` parameters — confirm whether they should track them.
uhgg_gene_length = (
    pd.read_table(uhgg_gene_length_inpath)
    .groupby("centroid_75")
    .centroid_99_length.mean()
)

In [None]:
eggnog_column_names = "query seed_ortholog evalue score eggNOG_OGs max_annot_lvl COG_category Description Preferred_name GOs EC KEGG_ko KEGG_inpathway KEGG_Module KEGG_Reaction KEGG_rclass BRITE KEGG_TC CAZy BiGG_Reaction PFAMs".split(
    " "
)
_gene_annotations = (
    pd.read_table(
        gene_annotations_inpath,
        comment="#",
        names=eggnog_column_names,
        index_col="query",
    )
    .rename_axis(index="gene_id")
    .replace({"-": np.nan})
)
_gene_annotations = uhgg_gene_length.to_frame().join(_gene_annotations)

_gene_annotations.info()

In [None]:
gene_x_cog_category1 = (
    _gene_annotations.COG_category.fillna("-").apply(list).explode()[lambda x: x != "-"]
)
gene_x_cog_category1

In [None]:
cog_x_category = pd.read_table(
    "ref/cog-20.meta.tsv",
    names=["cog", "cog_category", "description", "short_name", "_4", "_5", "_6"],
    index_col="cog",
).cog_category
cog_x_category

In [None]:
_gene_annotations.eggNOG_OGs

In [None]:
# One row per (gene, COG) pair: split the comma-separated `eggNOG_OGs` field,
# keep only entries that start with "COG", and drop everything from the "@"
# onwards to leave the bare identifier. Genes with no annotation become an
# empty string, which the "COG" prefix filter removes.
gene_x_cog = (
    _gene_annotations.eggNOG_OGs.fillna("")
    .str.split(",")
    .explode()[lambda x: x.str.startswith("COG")]
    .str.split("@")
    .str[0]
)
# NOTE: This expression is not the last one in the cell, so its value is
# computed and then discarded (nothing is displayed).
gene_x_cog.value_counts().head()
# Map each COG to its category string via `cog_x_category`, drop COGs with no
# entry, and explode multi-letter category strings into one row per letter.
gene_x_cog_category2 = gene_x_cog.map(cog_x_category).dropna().apply(list).explode()
gene_x_cog_category2

In [None]:
gene_x_cog_category = (
    pd.concat(
        [
            gene_x_cog_category1,
            gene_x_cog_category2,  # FIXME: Which metadata table do I want?
        ]
    )
    .reset_index()
    .drop_duplicates()
)

gene_x_cog_category.columns = ["centroid_75", "cog_category"]
gene_x_cog_category = gene_x_cog_category.set_index("centroid_75").cog_category
gene_x_cog_category.shape[0]

In [None]:
gene_annotations = _gene_annotations.assign(
    COG_category=(
        gene_x_cog_category.sort_values()
        .reset_index()
        .groupby("centroid_75")
        .apply(lambda x: "".join(x.cog_category.values))
    )
).assign(COG_category=lambda x: x.COG_category.fillna(""))

In [None]:
gene_x_cog_category_matrix = (
    gene_x_cog_category.reset_index()
    .assign(tally=True)
    .set_index(["centroid_75", "cog_category"])
    .tally.unstack("cog_category")
    .fillna(False)
    .reindex(gene_annotations.index, fill_value=False)
    .assign(no_category=lambda x: x.sum(1) == 0)
)
gene_x_cog_category_matrix

### Samples

In [None]:
# TODO: Read in a list of relevant samples and filter to only strains found in these.

In [None]:
species_depth = pd.read_table(
    species_depth_inpath, names=["sample", "depth"], index_col=["sample"]
).depth
species_depth

In [None]:
mgen = pd.read_table(mgen_inpath, index_col="library_id")
preparation = pd.read_table(preparation_inpath, index_col="preparation_id")
stool = pd.read_table(stool_inpath, index_col="stool_id")
subject = pd.read_table(subject_inpath, index_col="subject_id")

mgen_meta = (
    mgen.join(preparation, on="preparation_id", lsuffix="_mgen", rsuffix="_preparation")
    .join(stool, on="stool_id")
    .join(subject, on="subject_id")
)

In [None]:
subject_palette = lib.plot.construct_ordered_palette(mgen_meta.subject_id.unique())
subject_palette

In [None]:
sample_list = mgen.index
# `Index.is_unique` is a boolean property. The previous `assert sample_list.unique`
# asserted a bound method, which is always truthy, so duplicated IDs were never
# caught.
assert sample_list.is_unique, "Duplicate entries found in the index of `mgen`."
sample_list = sample_list.to_list()
len(sample_list), sample_list[:5], sample_list[-5:]

In [None]:
uhgg_depth = xr.load_dataarray(uhgg_depth_inpath)
subject_uhgg_depth = (
    mgen_meta[["subject_id"]]
    .join(uhgg_depth.to_pandas())
    .groupby("subject_id")
    .mean()
    .dropna()
)

In [None]:
uhgg_depth_ratio = uhgg_depth / species_depth.to_xarray().sel(sample=uhgg_depth.sample)

In [None]:
subject_mean_species_depth = (
    mgen_meta[["subject_id"]]
    .join(species_depth)
    .groupby("subject_id")
    .depth.mean()
    .dropna()
)

In [None]:
total_subject_uhgg_depth = (
    mgen_meta[["subject_id"]]
    .join(uhgg_depth.to_pandas())
    .groupby("subject_id")
    .sum()
    .dropna()
)
total_subject_species_depth = (
    mgen_meta[["subject_id"]]
    .join(species_depth)
    .groupby("subject_id")
    .depth.sum()
    .dropna()
)
subject_uhgg_depth_ratio = total_subject_uhgg_depth.divide(
    total_subject_species_depth, axis=0
)

### Strains

#### Ref Strains

In [None]:
# "Re-calculated GT-Pro genotype"
ref_geno = xr.load_dataarray(ref_geno_inpath)
ref_geno["genome_id"] = ref_geno.genome_id.to_series().map(
    lambda s: "UHGG" + s[len("GUT_GENOME") :]
)
ref_geno = ref_geno.fillna(0.5)
ref_geno = sf.data.Genotype(ref_geno.rename({"genome_id": "strain"}).T)

In [None]:
reference_meta = (
    pd.read_table("ref/uhgg_genomes_all_4644.tsv", index_col="Genome")
    .rename_axis(index="genome_id")[
        lambda x: x.MGnify_accession == "MGYG-HGUT-" + species_id[1:]
    ]
    .rename(lambda s: "UHGG" + s[len("GUT_GENOME") :])
    .loc[ref_geno.strain]
)

In [None]:
bins = np.linspace(0, 100, num=101)
plt.hist(reference_meta.Completeness, bins=bins)
plt.hist(reference_meta.Contamination, bins=bins)
None

In [None]:
ref_gene_copy_number_uhgg = xr.load_dataarray(ref_gene_copy_number_uhgg_inpath)
ref_gene_uhgg = (ref_gene_copy_number_uhgg > 0).astype(int).to_pandas().T

In [None]:
ref_gene_uhgg.shape

#### SPGC Strains

In [None]:
sample_to_spgc = pd.read_table(sample_to_spgc_inpath, index_col="sample").strain.astype(
    str
)

In [None]:
strain_x_sample_list_count = (
    sample_to_spgc.to_frame()
    .assign(in_sample_list=lambda x: x.index.isin(sample_list))
    .value_counts()
    .unstack("in_sample_list", fill_value=0)
    .reindex(columns=[False, True], fill_value=0)
    .rename(
        columns={
            True: "num_samples_in_sample_list",
            False: "num_samples_not_in_sample_list",
        }
    )
)

In [None]:
# Load the fitted sfacts world, applying `drop_low_abundance_strains(0.5)`.
sfacts_fit = sf.World.load(sfacts_fit_inpath).drop_low_abundance_strains(0.5)
# Strains should be str not int.
sfacts_fit.data["strain"] = sfacts_fit.strain.values.astype(str)
print(dict(sfacts_fit.sizes))

# Pre-calculate shared heatmap decorations
# A subsample of at most 500 positions, reused by the plotting cells below.
# NOTE(review): no seed is passed here; if `random_sample` draws from global
# RNG state, the subset (and every figure built on it) will differ between
# runs — confirm.
position_ss = sfacts_fit.random_sample(
    position=min(500, sfacts_fit.sizes["position"])
).position

# Construct a strain palette
# Strains are ordered by the leaf order of a genotype linkage so that the
# palette is assigned in that order.
_world = sfacts_fit.drop_low_abundance_strains(0.05)
_sfacts_list = list(
    linkage_order(
        _world.genotype.linkage(optimal_ordering=True),
        _world.strain.values,
    )
)
# NOTE(review): `list.remove` raises ValueError if "-1" is absent; this
# presumably relies on `drop_low_abundance_strains` always producing a "-1"
# strain — confirm.
_sfacts_list.remove("-1")  # Drop "other" strain.
spgc_palette = lib.plot.construct_ordered_palette(
    _sfacts_list,
    cm="rainbow",
)

In [None]:
spgc_est_mgtp = sf.Metagenotype.load(spgc_agg_mgtp_inpath)
spgc_est_mgtp.data["sample"] = spgc_est_mgtp.data["sample"].to_series().apply(str)
spgc_est_geno = spgc_est_mgtp.to_estimated_genotype(pseudo=0)
spgc_est_mgtp.sizes

In [None]:
unifrac_pdist = sfacts_fit.unifrac_pdist()
unifrac_cdmat = sp.spatial.distance.squareform(unifrac_pdist)
spgc_sample_linkage = sp.cluster.hierarchy.optimal_leaf_ordering(
    fastcluster.linkage(unifrac_cdmat), unifrac_cdmat
)

In [None]:
w = sfacts_fit.sel(position=position_ss)
try:
    spgc_position_ss_linkage = fastcluster.linkage(
        sp.spatial.distance.squareform(
            spgc_est_geno.sel(position=position_ss).pdist("position")
        )
    )
except ValueError as err:
    print(err)
    spgc_position_ss_linkage = None
spgc_sample_colors = (
    sfacts_fit.sample.to_series().map(sample_to_spgc).map(spgc_palette)
)  # Color samples that are strain-pure.
spgc_strain_linkage = fastcluster.linkage(
    sp.spatial.distance.squareform(w.genotype.pdist("strain"))
)
spgc_strain_colors = sfacts_fit.strain.to_series().map(spgc_palette)

In [None]:
if show_unimportant_figures:
    sf.plot.plot_community(
        sfacts_fit.sel(position=position_ss),
        col_linkage_func=lambda w: spgc_sample_linkage,
        row_linkage_func=lambda w: spgc_strain_linkage,
        col_colors=spgc_sample_colors,
        row_colors=spgc_strain_colors,
    )

In [None]:
if show_unimportant_figures:
    sf.plot.plot_metagenotype(
        sfacts_fit.sel(position=position_ss),
        col_linkage_func=lambda w: spgc_sample_linkage,
        row_colors=spgc_sample_colors,
        row_linkage_func=lambda w: spgc_position_ss_linkage,
        transpose=True,
    )

In [None]:
if show_unimportant_figures:
    g = sf.data.Genotype.concat(
        dict(mgen=spgc_est_geno, fit=sfacts_fit.genotype), dim="strain"
    ).mlift("fillna", 0.5)
    g_pdist = g.pdist()
    g_linkage = g.linkage()

    sf.plot.plot_genotype(
        g.sel(position=position_ss),
        # transpose=True,
        row_linkage_func=lambda w: g_linkage,
        col_linkage_func=lambda w: spgc_position_ss_linkage,
    )

In [None]:
spgc_meta = (
    pd.read_table(spgc_meta_inpath, index_col="genome_id")
    .rename_axis(index="strain")
    .rename(str)
    .join(strain_x_sample_list_count)
)
print(spgc_meta.shape)
spgc_meta

In [None]:
assert spgc_meta.assign(
    match_counts=lambda x: (
        x.num_samples_in_sample_list + x.num_samples_not_in_sample_list
    )
    == x.num_sample
).match_counts.all()

In [None]:
spgc_gene_uhgg = pd.read_table(spgc_gene_uhgg_inpath, index_col="gene_id").rename_axis(
    columns="strain"
)

In [None]:
ref_num_genes_uhgg = ref_gene_uhgg.sum()
spgc_num_genes_uhgg = spgc_gene_uhgg.sum()

### Strain Selection / Filtering

In [None]:
# NOTE: Select any ref genotype that is within the top-10 closest distances from an SPGC strain.
# spgc_to_ref_geno_cdist = ref_geno.sel(
#     position=spgc_est_geno.position,
#     strain=idxwhere(
#         (reference_meta.Completeness > 97) & (reference_meta.Contamination < 2)
#     ),
# ).cdist(spgc_est_geno)
# ref_list = list(
#     spgc_to_ref_geno_cdist.apply(lambda x: x.sort_values().head(10).index)
#     .stack()
#     .unique()
# )

# Select all high-enough-quality refs
ref_list = idxwhere(
    (reference_meta.Completeness > 90) & (reference_meta.Contamination < 2)
)

print(
    len(ref_list),
    "of",
    len(reference_meta),
    "reference genomes pass completeness and contamination thresholds.",
)

In [None]:
reference_meta

In [None]:
# Per-gene prevalence p: the mean of the 0/1 presence table across the
# QC-passing reference genomes in `ref_list`.
_ref_uhgg_prevalence = ref_gene_uhgg[ref_list].mean(1)
# -p * log2(p) per gene. Where p == 0 the product is NaN (0 * -inf), which is
# replaced by 0.
# NOTE(review): only the "present" term is computed; a two-state entropy would
# also add -(1 - p) * log2(1 - p). Confirm this is intended.
ref_gene_uhgg_entropy = (-_ref_uhgg_prevalence * np.log2(_ref_uhgg_prevalence)).fillna(
    0
)
# Bins: a leading zero edge followed by log-spaced edges from 1e-3 to 1e1.
plt.hist(ref_gene_uhgg_entropy, bins=[0] + list(np.logspace(-3, 1)))
plt.xscale("symlog", linthresh=1e-3, linscale=0.1)
plt.yscale("log")

In [None]:
# Derive lower/upper bounds on the number of genes per strain.
# x: `num_genes` for SPGC strains whose `species_gene_frac` exceeds 0.9.
# y: gene counts for all reference genomes (plotted for comparison only).
x = spgc_meta[lambda x: x.species_gene_frac > 0.9].num_genes
y = ref_num_genes_uhgg
# Fit a Student's t distribution to x with the degrees of freedom fixed at 2,
# estimating only location and scale.
_df, _loc, _scale = sp.stats.t.fit(x.values, fix_df=2)
_dist0 = sp.stats.t(_df, _loc, _scale)
# A normal distribution reusing the t fit's location and scale.
_dist1 = sp.stats.norm(_loc, _scale)

# Thresholds are the 0.1% and 99.9% quantiles of the *normal* (`_dist1`),
# not of the fitted t distribution.
thresh_max_num_uhgg_genes = _dist1.ppf(0.999)
thresh_min_num_uhgg_genes = _dist1.ppf(0.001)


bins = np.linspace(0, x.max() * 1.5, num=50)
xx = np.linspace(0, x.max() * 1.5, num=1000)

# Histograms of the SPGC (x) and reference (y) gene counts.
plt.hist(x, density=True, bins=bins, alpha=0.2)
plt.hist(y, density=True, bins=bins, alpha=0.2)

# Solid line: fitted t density. Dashed line: the normal used for thresholds.
# Vertical dashed lines mark the two thresholds.
plt.plot(xx, _dist0.pdf(xx), color="k")
plt.plot(xx, _dist1.pdf(xx), color="k", linestyle="--")
plt.axvline(thresh_max_num_uhgg_genes, lw=1, linestyle="--", color="k")
plt.axvline(thresh_min_num_uhgg_genes, lw=1, linestyle="--", color="k")

In [None]:
thresh_min_num_uhgg_genes, thresh_max_num_uhgg_genes

In [None]:
plt.scatter(
    "sum_depth",
    "species_gene_frac",
    c=(spgc_meta.num_genes - _loc) / _scale,
    data=spgc_meta,
    norm=mpl.colors.PowerNorm(1 / 1, vmin=-4, vmax=4),
)
plt.xscale("log")
# plt.yscale("logit")
plt.axvline(1, lw=1, color="k", linestyle="--")
plt.axhline(0.9, lw=1, color="k", linestyle="--")
plt.colorbar()

In [None]:
plt.scatter(
    "max_depth",
    "sum_depth",
    c=1 - spgc_meta.species_gene_frac,
    data=spgc_meta,
    norm=mpl.colors.LogNorm(),
)
plt.plot([0, 1e2], [0, 1e2])
plt.xscale("log")
plt.yscale("log")
# plt.axvline(1, lw=1, color="k", linestyle="--")
# plt.axhline(0.9, lw=1, color="k", linestyle="--")
plt.colorbar()

In [None]:
# NOTE: Select SPGC strains that pass various filters
spgc_list = idxwhere(
    (spgc_meta.sum_depth > 1)
    & (spgc_meta.num_samples_in_sample_list > 0)
    & (spgc_meta.species_gene_frac > 0.9)
    & (spgc_num_genes_uhgg <= thresh_max_num_uhgg_genes)
    & (spgc_num_genes_uhgg >= thresh_min_num_uhgg_genes)
)

print(
    "Out of",
    spgc_meta.shape[0],
    "SPGC strains,",
    len(spgc_list),
    "passed QC. There are also",
    len(ref_list),
    "reference strains passing QC.",
)

In [None]:
print(
    (spgc_meta.loc[spgc_list].num_samples_not_in_sample_list > 0).sum(),
    "of these strains were also found in the excluded samples.\n",
)
for strain in idxwhere((spgc_meta.loc[spgc_list].num_samples_not_in_sample_list > 0)):
    print(strain, ":", idxwhere(sample_to_spgc == strain))

In [None]:
assert len(spgc_list) > 2

In [None]:
strain_geno = sf.Genotype.concat(
    dict(
        ref=ref_geno.sel(strain=ref_list, position=spgc_est_geno.position),
        spgc=spgc_est_geno.sel(strain=spgc_list),
    ),
    dim="strain",
    rename=False,
).mlift("fillna", 0.5)

In [None]:
# Estimate of how long it will take to run the full pdist operation (hours):
((strain_geno.sizes["strain"] / 200) ** 2 * 4.61) / 3600

In [None]:
# # Too slow!!!
# strain_geno_pdist = strain_geno.pdist(q=1)

# Instead we'll do something faster.
# For computational efficiency, we're downcasting to float32 and pre-transforming the genotypes into [-1, 1] space.
g = strain_geno  # .isel(strain=slice(0, 100))  # For testing purposes
x = sf.math.genotype_binary_to_sign(g.values).astype(np.float32)
# Then we'll run this on a local re-implementation of sf.math.genotype_dissimilarity_transformed_values.
# NOTE: Despite its `_jit` suffix, the function's `@jit` decorator is commented
# out where it is defined, so `pdist` calls it as plain Python for every pair
# of strains.
# This "old" matrix is kept for comparison against the precomputed distances
# loaded in a later cell.
strain_geno_pdist_old = pd.DataFrame(
    squareform(pdist(x, metric=genotype_dissimilarity_transformed_values_jit)),
    index=g.strain,
    columns=g.strain,
)

In [None]:
import pickle

# Load precomputed genotype distances: a dict with a condensed distance vector
# under "cdmat" and the matching labels under "labels".
# NOTE: `pickle.load` can execute arbitrary code; only use it on files produced
# by this project's own pipeline.
with open(geno_pdist_inpath, "rb") as f:
    d = pickle.load(f)

_cdmat = d["cdmat"]
# Rename reference genomes from the "GUT_GENOME" prefix to "UHGG", matching
# the renaming applied to `ref_geno` when it was loaded.
_labels = [l.replace("GUT_GENOME", "UHGG") for l in d["labels"]]

# Expand to a square matrix, then subset/reorder to the strains in `g`.
_pdmat = pd.DataFrame(squareform(_cdmat), index=_labels, columns=_labels).loc[
    g.strain, g.strain
]
# Compare the old (x-axis) and newly loaded (y-axis) pairwise distances.
plt.scatter(squareform(strain_geno_pdist_old), squareform(_pdmat), s=1)

In [None]:
strain_geno_pdist = (
    _pdmat  # Use the new masked hamming distance estimates instead of the old ones
)

In [None]:
strain_geno_cdmat = sp.spatial.distance.squareform(strain_geno_pdist)
# strain_geno_linkage = strain_geno.linkage(pdist_kwargs=dict(q=1), optimal_ordering=True)
strain_geno_linkage = sp.cluster.hierarchy.optimal_leaf_ordering(
    fastcluster.linkage(strain_geno_cdmat, method="average"), strain_geno_cdmat
)

In [None]:
genome_type = reference_meta.Genome_type.reindex(strain_geno.strain)
genome_type.loc[genome_type.index.isin(spgc_list)] = "SPGC"
genome_type_order = ["Isolate", "MAG", "SPGC"]
genome_type_palette = lib.plot.construct_ordered_palette(genome_type_order)

In [None]:
isolate_list = idxwhere(genome_type == "Isolate")
mag_list = idxwhere(genome_type == "MAG")

In [None]:
fig, ax = plt.subplots(figsize=(2, 1))
for _genome_type in genome_type_order:
    ax.scatter([], [], color=genome_type_palette[_genome_type], label=_genome_type)
ax.legend()

lib.plot.hide_axes_and_spines(ax)

In [None]:
fig, ax = plt.subplots(figsize=(2, 1))
for _genome_type in genome_type_order:
    ax.scatter([], [], color=genome_type_palette[_genome_type], label=_genome_type)
ax.legend(ncols=3)

lib.plot.hide_axes_and_spines(ax)

In [None]:
_colors = genome_type.map(genome_type_palette)
sns.clustermap(
    strain_geno_pdist,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_geno_linkage,
    col_linkage=strain_geno_linkage,
    vmin=0,
    vmax=1,
    xticklabels=0,
    yticklabels=0,
    # figsize=(40, 40),
)

In [None]:
fig, ax = plt.subplots(figsize=(10, 2))
sns.heatmap([[]], vmin=0, vmax=1, ax=ax, cbar_kws=dict(location="bottom"))
lib.plot.hide_axes_and_spines(ax=ax)

In [None]:
from scipy.spatial.distance import squareform

bins = np.linspace(0, 1)
plt.hist(
    squareform(strain_geno_pdist.loc[ref_list, ref_list]),
    bins=bins,
    histtype="step",
    density=True,
    label="ref-to-ref",
)
plt.hist(
    strain_geno_pdist.loc[ref_list, spgc_list].values.flatten(),
    bins=bins,
    histtype="step",
    density=True,
    label="spgc-to-ref",
)
plt.hist(
    squareform(strain_geno_pdist.loc[spgc_list, spgc_list]),
    bins=bins,
    histtype="step",
    density=True,
    label="spgc-to-spgc",
)
plt.legend()

In [None]:
# Cumulative distributions of each strain's nearest-neighbor genotype distance,
# split by which set the strain and its neighbor come from.
bins = np.linspace(0, 1, num=200)

# Add 1 to the diagonal so a strain's self-distance never wins the minimum.
_pdist = strain_geno_pdist + np.eye(len(strain_geno_pdist))

# (row strains, column strains, label). `.min()` reduces over the rows, giving
# one minimum per column strain.
for _row_strains, _col_strains, _label in [
    (ref_list, ref_list, "ref-to-ref"),
    (ref_list, spgc_list, "spgc-to-ref"),
    (spgc_list, ref_list, "ref-to-spgc"),
    (spgc_list, spgc_list, "spgc-to-spgc"),
]:
    plt.hist(
        _pdist.loc[_row_strains, _col_strains].min(),
        bins=bins,
        histtype="step",
        cumulative=True,
        density=True,
        label=_label,
    )

plt.legend()
# plt.axvline(0.43, lw=1, linestyle='--', color='k')
plt.xlabel("minimum distance")
plt.ylabel("cumulative fraction")

In [None]:
spgc_gene_uhgg.shape, ref_gene_uhgg.shape

In [None]:
strain_gene_uhgg = pd.concat(
    [ref_gene_uhgg[ref_list], spgc_gene_uhgg[spgc_list]], axis=1
).fillna(0)[lambda x: x.sum(1) > 0]
strain_gene_uhgg.shape

In [None]:
bins = np.linspace(0, 10000)
plt.hist(
    strain_gene_uhgg[ref_list].sum(),
    bins=bins,
    histtype="step",
    label="ref",
    density=True,
)
plt.hist(
    strain_gene_uhgg[spgc_list].sum(),
    bins=bins,
    histtype="step",
    label="spgc",
    density=True,
)
plt.legend()

None

### Strain Dereplication

In [None]:
# Dereplicate strains: cut the genotype linkage at a distance of 0.01 and give
# every strain the integer ID of the flat cluster it falls into.
strain_derep_clust = pd.Series(
    sp.cluster.hierarchy.fcluster(strain_geno_linkage, t=0.01, criterion="distance"),
    index=strain_geno.strain,
)

# Per-cluster tallies of SPGC and reference strains.
# Membership is encoded as a small integer: +2 if the strain is in `spgc_list`,
# +3 if it is in `ref_list`; so 2 means SPGC only, 3 means reference only, and
# 0 or 5 (neither / both) are left unmapped (NaN).
strain_derep_clust_stats = (
    strain_derep_clust.to_frame(name="clust")
    .assign(
        strain_type_flag=lambda x: x.index.isin(spgc_list) * 2
        + x.index.isin(ref_list) * 3,
        strain_type=lambda x: x.strain_type_flag.map(
            {
                2: "spgc",
                3: "ref",
                # 0: "neither??", 5: "both??"  (2 + 3; an earlier note said 6)
            }
        ),
    )[["clust", "strain_type"]]
    # Count strains per (cluster, type) and pivot types into columns.
    .value_counts()
    .unstack()
    # "neither??" and "both??" are never produced by the map above, so these
    # two columns are added empty and then zero-filled.
    .reindex(columns=["neither??", "spgc", "ref", "both??"])
    .fillna(0)
    # The same 2/3 encoding at the cluster level: 2 = only SPGC strains,
    # 3 = only reference strains, 5 = both kinds present.
    .assign(
        clust_type_flag=lambda x: (x.spgc > 0) * 2 + (x.ref > 0) * 3,
        clust_type=lambda x: x.clust_type_flag.map(
            {
                # 0: "neither??",
                2: "spgc",
                3: "ref",
                5: "both",
            }
        ),
    )
)

# Summary table (displayed as the cell output):
#   total_strains    - number of strains of each type,
#   derep_strain     - number of clusters containing at least one such strain,
#   derep_membership - number of clusters by composition (ref / both / spgc).
pd.DataFrame(
    dict(
        total_strains=strain_derep_clust_stats[["ref", "spgc"]].sum(),
        derep_strain=(strain_derep_clust_stats[["ref", "spgc"]] > 0).sum(),
        derep_membership=(strain_derep_clust_stats.clust_type.value_counts()),
    )
).reindex(["ref", "both", "spgc"]).fillna(0).astype(int)

In [None]:
d = (
    strain_derep_clust.to_frame(name="clust")
    .join(genome_type)
    .value_counts()
    .unstack(fill_value=0)
)

pd.DataFrame(
    dict(
        total_strain=d.sum().rename(lambda x: str([x])),
        derep_strain=(d > 0).sum().rename(lambda x: str([x])),
        derep_membership=d.apply(lambda x: idxwhere((x > 0)), axis=1)
        .value_counts()
        .rename(str),
    )
).fillna(0).astype(int)

## Strain Geno/Gene Spaces

In [None]:
spgc_gene_uhgg_depth = (
    pd.read_table(spgc_gene_uhgg_depth_inpath, index_col=["gene_id", "strain"])
    .depth.unstack("strain")
    .rename(columns=str)
)
spgc_gene_uhgg_corr = (
    pd.read_table(spgc_gene_uhgg_corr_inpath, index_col=["gene_id", "strain"])
    .correlation.unstack("strain")
    .rename(columns=str)
)

In [None]:
_ref = ref_gene_uhgg[ref_list]
_spgc = spgc_gene_uhgg_depth[spgc_list]
# The "depth" for the references is nominally set to 1, so we just concatenate the boolean ref_gene_uhgg.
strain_gene_uhgg_depth = (
    pd.concat([_ref, _spgc], axis=1).fillna(0).loc[strain_gene_uhgg.index]
)

In [None]:
_ref = ref_gene_uhgg[ref_list]
_spgc = spgc_gene_uhgg_corr[spgc_list]
# The "correlation" for the references is nominally set to 1, so we just concatenate the boolean ref_gene_uhgg.
strain_gene_uhgg_corr = (
    pd.concat([_ref, _spgc], axis=1).fillna(0).loc[strain_gene_uhgg.index]
)

### Gene Filtering

#### Enriched/Depleted in SPGC Strains

In [None]:
# NOTE: Using a pseudocount so that ratios between these two values are always defined.
ref_uhgg_prevalence = (strain_gene_uhgg[ref_list].sum(1) + 1) / (len(ref_list) + 1)
spgc_uhgg_prevalence = (strain_gene_uhgg[spgc_list].sum(1) + 1) / (len(spgc_list) + 1)

In [None]:
x = ref_uhgg_prevalence
y = spgc_uhgg_prevalence

print(sp.stats.pearsonr(x, y))

fig, axs = plt.subplots(2, figsize=(5, 10))

bins0 = np.linspace(0.0, 1.0, num=50)
axs[0].hist2d(x, y, bins=bins0, norm=mpl.colors.PowerNorm(1 / 3, vmin=0, vmax=1e3))

bins1 = np.linspace(0.1, 0.9, num=40)
axs[1].hist2d(x, y, bins=bins1, norm=mpl.colors.PowerNorm(1 / 3))
axs[1].set_xlabel("reference prevalence")
axs[1].set_ylabel("inferred prevalence")
None

In [None]:
bins = np.logspace(-4, 4)
plt.hist(
    (spgc_uhgg_prevalence / ref_uhgg_prevalence).replace({0: 1e-3, np.inf: 1e3}),
    bins=bins,
    alpha=0.5,
)
plt.hist(
    (ref_uhgg_prevalence / spgc_uhgg_prevalence).replace({0: 1e-3, np.inf: 1e3}),
    bins=bins,
    alpha=0.5,
)
plt.xscale("log")
plt.yscale("log")

In [None]:
# Prevalence ratios in each direction. Both prevalences were computed with a
# pseudocount, so the ratios are always defined.
_spgc_over_ref = spgc_uhgg_prevalence / ref_uhgg_prevalence
_ref_over_spgc = ref_uhgg_prevalence / spgc_uhgg_prevalence

# Nested enrichment/depletion classes: each stricter threshold selects a
# subset of the genes selected by the looser ones.
spgc_extremely_enriched = idxwhere(_spgc_over_ref > 1000)
spgc_extremely_depleted = idxwhere(_ref_over_spgc > 1000)
spgc_very_enriched = idxwhere(_spgc_over_ref > 100)
spgc_very_depleted = idxwhere(_ref_over_spgc > 100)
spgc_highly_enriched = idxwhere(_spgc_over_ref > 10)
spgc_highly_depleted = idxwhere(_ref_over_spgc > 10)
spgc_enriched = idxwhere(_spgc_over_ref > 5)
spgc_depleted = idxwhere(_ref_over_spgc > 5)
# Everything that is neither enriched nor depleted. Uses `<=` so that a ratio
# of exactly 5 is classified as "similar" instead of falling into no class
# (the enriched/depleted tests above are strict `> 5`).
spgc_similar = idxwhere((_spgc_over_ref <= 5) & (_ref_over_spgc <= 5))

#### Gene Length / Singletons

In [None]:
# For each enrichment class, from most enriched to most depleted, print three
# unlabeled values per line: the number of genes in the class, then the mean
# and standard deviation of their `uhgg_gene_length`.
# NOTE: The enriched/depleted classes are nested, so a gene can be counted in
# more than one line.
for spgc_enrichment_class_list in [
    spgc_extremely_enriched,
    spgc_very_enriched,
    spgc_highly_enriched,
    spgc_enriched,
    spgc_similar,
    spgc_depleted,
    spgc_highly_depleted,
    spgc_very_depleted,
    spgc_extremely_depleted,
]:
    print(
        len(spgc_enrichment_class_list),
        uhgg_gene_length.loc[spgc_enrichment_class_list].mean(),
        uhgg_gene_length.loc[spgc_enrichment_class_list].std(),
    )

In [None]:
short_genes = idxwhere(uhgg_gene_length < 300)
singleton_genes = idxwhere(strain_gene_uhgg.sum(1) <= 1)

#### Final Filter

In [None]:
# Genes to exclude: those more than 10x enriched or depleted in SPGC strains
# relative to references, genes with length below 300, and genes present in
# at most one strain. The lists are simply concatenated, so a gene may appear
# more than once.
drop_list = spgc_highly_enriched + spgc_highly_depleted + short_genes + singleton_genes
# Size of each exclusion list (overlaps are not removed from these counts).
print(
    len(spgc_highly_enriched),
    len(spgc_highly_depleted),
    len(short_genes),
    len(singleton_genes),
)

# `errors="ignore"` skips labels in `drop_list` that are not in the index
# (e.g. short genes that are not in `strain_gene_uhgg` at all).
strain_uhgg_filt = strain_gene_uhgg.drop(index=drop_list, errors="ignore")
# Number of genes before and after filtering.
print(strain_gene_uhgg.shape[0], strain_uhgg_filt.shape[0])

In [None]:
bins = np.linspace(0, strain_gene_uhgg[spgc_list + ref_list].sum().max(), num=50)

plt.hist(
    strain_gene_uhgg[ref_list].sum(),
    bins=bins,
    histtype="stepfilled",
    label="ref (uhgg)",
    density=True,
    color="tab:blue",
    alpha=0.5,
)
plt.hist(
    strain_gene_uhgg[spgc_list].sum(),
    bins=bins,
    histtype="stepfilled",
    label="spgc (uhgg)",
    density=True,
    color="tab:orange",
    alpha=0.5,
)

plt.hist(
    strain_uhgg_filt[ref_list].sum(),
    bins=bins,
    histtype="step",
    label="ref (filtered uhgg)",
    density=True,
    linestyle="-",
    color="tab:blue",
    lw=2,
)
plt.hist(
    strain_uhgg_filt[spgc_list].sum(),
    bins=bins,
    histtype="step",
    label="spgc (filtered uhgg)",
    density=True,
    linestyle="-",
    color="tab:orange",
    lw=2,
)

# plt.hist(strain_gene_eggnog[ref_list].sum(), bins=bins, histtype='stepfilled', label='ref (eggnog)', density=True, color='tab:blue', alpha=0.3)
# plt.hist(strain_gene_eggnog[spgc_list].sum(), bins=bins, histtype='stepfilled', label='spgc (eggnog)', density=True, color='tab:orange', alpha=0.3)

# plt.hist(strain_gene_filt_eggnog[ref_list].sum(), bins=bins, histtype='step', label='ref (filtered eggnog)', density=True, linestyle='--', color='tab:blue', lw=1.5)
# plt.hist(strain_gene_filt_eggnog[spgc_list].sum(), bins=bins, histtype='step', label='spgc (filtered eggnog)', density=True, linestyle='--', color='tab:orange', lw=1.5)

plt.legend(bbox_to_anchor=(1, 1))

None

In [None]:
uhgg_filt_cdmat = sp.spatial.distance.pdist(strain_uhgg_filt, metric="cosine")
# uhgg_filt_pdist = pd.DataFrame(
#     sp.spatial.distance.squareform(uhgg_filt_cdmat),
#     index=strain_uhgg_filt.index,
#     columns=strain_uhgg_filt.index,
# )

In [None]:
# NOTE: Not using optimal ordering here. It takes way too long.
uhgg_filt_linkage = fastcluster.linkage(uhgg_filt_cdmat, method="average")

In [None]:
_strain_uhgg_filt_unweighted_jacc_cdmat = sp.spatial.distance.pdist(
    strain_uhgg_filt.T, metric="jaccard"
)
_strain_uhgg_filt_unweighted_jacc_pdist = pd.DataFrame(
    sp.spatial.distance.squareform(_strain_uhgg_filt_unweighted_jacc_cdmat),
    index=strain_uhgg_filt.columns,
    columns=strain_uhgg_filt.columns,
)

In [None]:
import pickle

# Load the precomputed gene-content distances (condensed vector plus labels).
# NOTE(review): `pickle.load` can execute arbitrary code embedded in the file; only
# point `gene_pdist_inpath` at trusted, pipeline-generated files.
with open(gene_pdist_inpath, "rb") as f:
    d = pickle.load(f)
_cdmat, _labels = d["cdmat"], d["labels"]

_pdmat = pd.DataFrame(squareform(_cdmat), index=_labels, columns=_labels)

# Sanity check: distances recomputed in this notebook (x-axis) against the loaded
# ones (y-axis) for the same genome pairs. The black line is the identity.
plt.scatter(
    squareform(_strain_uhgg_filt_unweighted_jacc_pdist),
    squareform(
        _pdmat.loc[
            _strain_uhgg_filt_unweighted_jacc_pdist.index,
            _strain_uhgg_filt_unweighted_jacc_pdist.columns,
        ]
    ),
)
plt.plot([0, 1], [0, 1], color="k")

In [None]:
# Use the distances loaded into `_pdmat` (rather than the ones recomputed in this
# notebook), subset and re-ordered to the same labels as the recomputed matrix.
strain_uhgg_filt_unweighted_jacc_pdist = _pdmat.loc[
    _strain_uhgg_filt_unweighted_jacc_pdist.index,
    _strain_uhgg_filt_unweighted_jacc_pdist.columns,
]
# Condensed (vector) form of the same distances, as consumed by the linkage code.
strain_uhgg_filt_unweighted_jacc_cdmat = squareform(
    strain_uhgg_filt_unweighted_jacc_pdist
)

In [None]:
# FIXME: Redundant? Delete this?
# Average-linkage clustering of genomes on the Jaccard gene-content distances, with
# the leaves re-ordered by `optimal_leaf_ordering`. Later cells pass this linkage to
# `sns.clustermap` as a row or column dendrogram.
strain_uhgg_filt_unweighted_jacc_linkage = sp.cluster.hierarchy.optimal_leaf_ordering(
    fastcluster.linkage(strain_uhgg_filt_unweighted_jacc_cdmat, method="average"),
    strain_uhgg_filt_unweighted_jacc_cdmat,
)
# strain_uhgg_filt_unweighted_linkage = sp.cluster.hierarchy.optimal_leaf_ordering(
#     fastcluster.linkage(strain_uhgg_filt_unweighted_cdmat, method="average"),
#     strain_uhgg_filt_unweighted_cdmat,
# )

In [None]:
# Heatmap of filtered gene content: rows are placed in the leaf order of the row
# linkage (no row dendrogram is drawn); columns use the genotype linkage.
_row_linkage = uhgg_filt_linkage
_col_linkage = strain_geno_linkage
_col_colors = genome_type.map(genome_type_palette)

# Order x by leaf order.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.ClusterNode.pre_order.html#scipy.cluster.hierarchy.ClusterNode.pre_order>
x = strain_uhgg_filt.iloc[
    sp.cluster.hierarchy.to_tree(_row_linkage).pre_order(lambda node: node.id)
]

if show_unimportant_figures:
    sns.clustermap(
        x,
        row_cluster=False,
        col_linkage=_col_linkage,
        col_colors=_col_colors,
    )

In [None]:
# Same heatmap as the previous cell, but columns are ordered by the gene-content
# (Jaccard) linkage instead of the genotype linkage.
_row_linkage = uhgg_filt_linkage
_col_linkage = strain_uhgg_filt_unweighted_jacc_linkage
_col_colors = genome_type.map(genome_type_palette)

# Order x by leaf order.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.ClusterNode.pre_order.html#scipy.cluster.hierarchy.ClusterNode.pre_order>
x = strain_uhgg_filt.iloc[
    sp.cluster.hierarchy.to_tree(_row_linkage).pre_order(lambda node: node.id)
]

if show_unimportant_figures:
    sns.clustermap(
        x,
        row_cluster=False,
        col_linkage=_col_linkage,
        col_colors=_col_colors,
    )

In [None]:
# Collapse gene presence/absence to the eggNOG level: an eggNOG group counts as
# present in a genome if any of its member genes is present.
d = (
    strain_gene_uhgg.astype(bool)
    .join(uhgg_x_eggnog.set_index("gene_id").eggnog)
    .groupby("eggnog")
    .any()
)
# Per-eggNOG prevalence (fraction of genomes) among reference and SPGC genomes.
x = d[ref_list].mean(axis=1)
y = d[spgc_list].mean(axis=1)

print(sp.stats.pearsonr(x, y))

fig, axs = plt.subplots(2, figsize=(5, 10))

# Top panel: the full [0, 1] prevalence range.
bins0 = np.linspace(0.0, 1.0, num=50)
axs[0].hist2d(x, y, bins=bins0, norm=mpl.colors.PowerNorm(1 / 3, vmin=0, vmax=1e3))

# Bottom panel: zoom on intermediate prevalence, leaving out the extremes.
bins1 = np.linspace(0.1, 0.9, num=40)
axs[1].hist2d(x, y, bins=bins1, norm=mpl.colors.PowerNorm(1 / 3))
axs[1].set_xlabel("reference prevalence")
axs[1].set_ylabel("inferred prevalence")
None

In [None]:
# Average-linkage clustering of genomes (columns of `d`, transposed to rows here) on
# the Jaccard dissimilarity of their per-eggNOG presence/absence profiles.
_strain_linkage = fastcluster.linkage(d[ref_list + spgc_list].T, metric='jaccard', method="average")

In [None]:
# Heatmap of eggNOG-level presence/absence for reference + SPGC genomes, with columns
# ordered by the eggNOG Jaccard linkage. Rows keep the order they have in `d`.
x = d[ref_list + spgc_list]
_col_linkage = _strain_linkage
# _row_linkage = _gene_linkage
# Order x by leaf order.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.ClusterNode.pre_order.html#scipy.cluster.hierarchy.ClusterNode.pre_order>
# x = x.iloc[sp.cluster.hierarchy.to_tree(_row_linkage).pre_order(lambda x: x.id)]

_col_colors = genome_type.map(genome_type_palette)

# Only drawn when the notebook-level `show_unimportant_figures` flag is truthy.
if show_unimportant_figures:
    sns.clustermap(
        x, row_cluster=False, col_linkage=_col_linkage, col_colors=_col_colors
    )

### Gene Content Divergence from Refs

Are SPGC Genomes Believably Distinct?

Test whether SPGC strains are surprisingly distinct from the reference genomes in gene content, relative to how distinct they are in genotype.

#### Min-genotype / Min-gene content

#### Min-genotype / Matched-gene content

In [None]:
# For every isolate / MAG / SPGC genome, find the reference genome that is closest in
# genotype (matrix A) and record both the genotype distance and the gene-content
# distance (matrix B) to that same reference.

# gene_pdist_adjust = (1 / strain_gene_filt_eggnog.shape[0])

# Remove the diagonal from "minimum distance": adding 1 to each self-distance keeps a
# genome from being picked as its own nearest reference.
_pdistA = strain_geno_pdist + np.eye(len(strain_geno_pdist))
_pdistB = strain_uhgg_filt_unweighted_jacc_pdist + np.eye(
    len(strain_uhgg_filt_unweighted_jacc_pdist)
)

# Long-form (row label, column label) -> distance lookups, built once and shared.
_stackedA = _pdistA.stack()
_stackedB = _pdistB.stack()

_nearest = {}
for _label, _query_list in [
    ("Isolate", isolate_list),
    ("MAG", mag_list),
    ("SPGC", spgc_list),
]:
    # Series: query genome -> label of its minimum-genotype-distance reference.
    _to_ref_min = _pdistA.loc[ref_list, _query_list].idxmin()
    # (query genome, nearest reference) pairs used to look up the two distances.
    _pairs = _to_ref_min.reset_index().set_index(["index", 0]).index
    _nearest[_label] = (
        _to_ref_min,
        _stackedA[_pairs].reset_index(level=1, drop=True),
        _stackedB[_pairs].reset_index(level=1, drop=True),
    )

isolate_to_ref_min, isolate_to_ref_min_pdistA, isolate_to_ref_min_pdistB = _nearest[
    "Isolate"
]
mag_to_ref_min, mag_to_ref_min_pdistA, mag_to_ref_min_pdistB = _nearest["MAG"]
spgc_to_ref_min, spgc_to_ref_min_pdistA, spgc_to_ref_min_pdistB = _nearest["SPGC"]

d0 = pd.DataFrame(
    dict(
        min_geno_ref=pd.concat([isolate_to_ref_min, mag_to_ref_min, spgc_to_ref_min]),
        geno_dist=pd.concat(
            [isolate_to_ref_min_pdistA, mag_to_ref_min_pdistA, spgc_to_ref_min_pdistA]
        ),
        gene_dist=pd.concat(
            [isolate_to_ref_min_pdistB, mag_to_ref_min_pdistB, spgc_to_ref_min_pdistB]
        ),
        genome_type=(
            ["Isolate"] * len(isolate_to_ref_min_pdistA)
            + ["MAG"] * len(mag_to_ref_min_pdistA)
            + ["SPGC"] * len(spgc_to_ref_min_pdistA)
        ),
    )
)
# (Disabled) pseudocount variant:
#     .assign(
#         geno_dist_pc=lambda x: x.geno_dist + geno_pdist_adjust,
#         gene_dist_pc=lambda x: x.gene_dist + gene_pdist_adjust,
#     )

# Linear model: gene-content distance against log2 genotype distance, with a separate
# intercept and slope per genome type.
fit = smf.ols("gene_dist ~ np.log2(geno_dist) * genome_type", data=d0).fit()
genome_distance_comparison_filt_unweighted_jacc = d0.assign(
    gene_dist_predict=lambda x: fit.predict(),
    gene_dist_resid_pearson=fit.resid_pearson,
).sort_values("geno_dist")

fig, axs = plt.subplots(1, 2, figsize=(10, 4))

# Both panels show the same points and fitted lines; the right panel additionally gets
# the legend and a symlog x-axis.
for ax in axs:
    for _genome_type, d2 in genome_distance_comparison_filt_unweighted_jacc.groupby(
        "genome_type"
    ):
        ax.scatter("geno_dist", "gene_dist", data=d2, label=_genome_type)
        ax.plot("geno_dist", "gene_dist_predict", data=d2, label="__nolegend__")
ax.legend()
ax.set_xscale("symlog", linthresh=1e-4)
fit.summary()

In [None]:
# Scatter of gene-content dissimilarity against SNP-profile dissimilarity to the
# nearest reference, one colour per genome type, with the fitted regression line.
# NOTE(review): `genome_type_to_label` is defined here but not used in this cell.
genome_type_to_label = {"isolate": "Isolate", "mag": "MAG", "spgc": "SPGC"}

fig, ax = plt.subplots(figsize=(5, 5))
for _genome_type, d2 in genome_distance_comparison_filt_unweighted_jacc.groupby(
    "genome_type"
):
    # The group keys are the labels assigned when the table was built
    # ("Isolate", "MAG", "SPGC"); they must also be keys of `genome_type_palette`.
    ax.scatter(
        "geno_dist",
        "gene_dist",
        data=d2,
        color=genome_type_palette[_genome_type],
        label=_genome_type,
        s=50,
        alpha=0.7,
        lw=0,
    )
    # Fitted values; the frame is sorted by `geno_dist`, so the line is drawn
    # left-to-right.
    ax.plot(
        "geno_dist",
        "gene_dist_predict",
        data=d2,
        color=genome_type_palette[_genome_type],
        label="__nolegend__",
    )
# ax.legend(markerscale=3)
# Symlog keeps zero distances plottable while spreading out small positive values.
ax.set_xscale("symlog", linthresh=1e-4)
ax.set_xlabel("SNP Profile Dissimilarity")
ax.set_ylabel("Gene Content Dissimilarity")
# ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4])

In [None]:
# TODO: Consider using a permutation test to say whether there is an overall relationship between
# genotype distance and gene content distance and whether this is different in SPGC strains compared
# to the references. (Similar to the EEN strain turnover analysis.)

## Strain Diversity Analysis

In [None]:
# Genotype dissimilarity heatmap with the diagonal (self-distances) masked out.
#
# The mask is applied on an explicit, writable float copy of the data. Assigning through
# `DataFrame.values` is not guaranteed to write back into the frame, and the array is
# read-only under pandas Copy-on-Write.
_dmat = strain_geno_pdist.to_numpy(dtype=float, copy=True)
dii, djj = np.diag_indices_from(_dmat)
_dmat[dii, djj] = np.nan
x = pd.DataFrame(
    _dmat, index=strain_geno_pdist.index, columns=strain_geno_pdist.columns
)

_colors = genome_type.map(genome_type_palette)

# Rows are ordered by the gene-content linkage and columns by the genotype linkage.
g = sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_uhgg_filt_unweighted_jacc_linkage,
    col_linkage=strain_geno_linkage,
    # figsize=(40, 40),
)
# NaN cells are not drawn, so the axes background colour marks the masked diagonal.
g.ax_heatmap.set_facecolor("aqua")
# g.cax.set_visible(False)

In [None]:
# Compact version of the genotype dissimilarity heatmap: no tick labels, small
# dendrograms, fixed [0, 1] colour range, colour bar hidden.
#
# The diagonal is masked on an explicit, writable float copy. Assigning through
# `DataFrame.values` is not guaranteed to write back into the frame, and the array is
# read-only under pandas Copy-on-Write.
_dmat = strain_geno_pdist.to_numpy(dtype=float, copy=True)
dii, djj = np.diag_indices_from(_dmat)
_dmat[dii, djj] = np.nan
x = pd.DataFrame(
    _dmat, index=strain_geno_pdist.index, columns=strain_geno_pdist.columns
)

_colors = pd.DataFrame(
    dict(
        g=genome_type.map(genome_type_palette),
    )
)

# Rows are ordered by the gene-content linkage and columns by the genotype linkage.
g = sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_uhgg_filt_unweighted_jacc_linkage,
    col_linkage=strain_geno_linkage,
    xticklabels=0,
    yticklabels=0,
    dendrogram_ratio=0.1,
    vmin=0,
    vmax=1,
    tree_kws=dict(lw=1.0),
    # figsize=(40, 40),
)
# NaN cells are not drawn, so the axes background colour marks the masked diagonal.
g.ax_heatmap.set_facecolor("aqua")
g.cax.set_visible(False)

In [None]:
# Gene-content (Jaccard) dissimilarity heatmap, laid out like the compact genotype
# heatmap so the two can be compared side by side.
#
# The diagonal is masked on an explicit, writable float copy. Assigning through
# `DataFrame.values` is not guaranteed to write back into the frame, and the array is
# read-only under pandas Copy-on-Write.
_dmat = strain_uhgg_filt_unweighted_jacc_pdist.to_numpy(dtype=float, copy=True)
dii, djj = np.diag_indices_from(_dmat)
_dmat[dii, djj] = np.nan
x = pd.DataFrame(
    _dmat,
    index=strain_uhgg_filt_unweighted_jacc_pdist.index,
    columns=strain_uhgg_filt_unweighted_jacc_pdist.columns,
)

_colors = pd.DataFrame(
    dict(
        g=genome_type.map(genome_type_palette),
    )
)

# Rows are ordered by the gene-content linkage and columns by the genotype linkage.
g = sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_uhgg_filt_unweighted_jacc_linkage,
    col_linkage=strain_geno_linkage,
    xticklabels=0,
    yticklabels=0,
    dendrogram_ratio=0.1,
    vmin=0,
    vmax=1,
    tree_kws=dict(lw=1.0),
    # figsize=(40, 40),
)
# NaN cells are not drawn, so the axes background colour marks the masked diagonal.
g.ax_heatmap.set_facecolor("aqua")
g.cax.set_visible(False)
# # Clip the "tips" of the dendrogram so that it's not just a black rectangle (for dense maps).
# g.ax_col_dendrogram.set_ylim(0.1, 1)
# g.ax_row_dendrogram.set_xlim(1, 0.1)

### Prevalence Comparisons

In [None]:
# 2-D histograms comparing per-gene prevalence in reference genomes (x) with
# prevalence in SPGC genomes (y).
x = ref_uhgg_prevalence
y = spgc_uhgg_prevalence

# Pearson correlation between the two prevalence vectors.
print(sp.stats.pearsonr(x, y))

fig, axs = plt.subplots(2, figsize=(5, 10))

# Top panel: the full [0, 1] range; the power-law norm compresses the very large
# counts at the extremes.
bins0 = np.linspace(0.0, 1.0, num=50)
axs[0].hist2d(x, y, bins=bins0, norm=mpl.colors.PowerNorm(1 / 3, vmin=0, vmax=1e3))

# Bottom panel: zoom on [0.1, 0.9]; values outside the bin edges are not counted.
bins1 = np.linspace(0.1, 0.9, num=40)
axs[1].hist2d(x, y, bins=bins1, norm=mpl.colors.PowerNorm(1 / 3))
axs[1].set_xlabel("reference prevalence")
axs[1].set_ylabel("inferred prevalence")
# Trailing `None` suppresses the cell's rich output.
None

### Core / Shell / Cloud Pangenome

In [None]:
def _assign_prevalence_class(p):
    if p > 0.95:
        return "core"
    elif p > 0.1:
        return "shell"
    elif p < 0.1:
        return "cloud"


prevalence_class_order = ["core", "shell", "cloud"]
prevalence_class_palette = {
    "core": "tab:blue",
    "shell": "tab:orange",
    "cloud": "tab:green",
}

In [None]:
# Histogram of gene prevalence among SPGC genomes, coloured by prevalence class.
bins = np.linspace(0, 1, num=51)

# Class label per gene, derived from its prevalence among SPGC genomes.
spgc_uhgg_class = spgc_uhgg_prevalence.map(_assign_prevalence_class)

for prevalence_class in prevalence_class_order:
    plt.hist(
        spgc_uhgg_prevalence[spgc_uhgg_class == prevalence_class],
        bins=bins,
        label=prevalence_class,
        color=prevalence_class_palette[prevalence_class],
    )
plt.legend()
# Log-scaled counts, so that sparsely populated bins remain visible.
plt.yscale("log")
# core_genes = idxwhere(strain_uhgg_filt_prevalence > 0.9)
# shell_genes = idxwhere((strain_uhgg_filt_prevalence < 0.9) & (strain_uhgg_filt_prevalence > 0.1))
# cloud_genes = idxwhere(strain_uhgg_filt_prevalence < 0.1)

In [None]:
# Histogram of gene prevalence among reference genomes, coloured by prevalence class.
bins = np.linspace(0, 1, num=51)

# Class label per gene, derived from its prevalence among reference genomes.
ref_uhgg_class = ref_uhgg_prevalence.map(_assign_prevalence_class)

for prevalence_class in prevalence_class_order:
    plt.hist(
        ref_uhgg_prevalence[ref_uhgg_class == prevalence_class],
        bins=bins,
        label=prevalence_class,
        color=prevalence_class_palette[prevalence_class],
    )
plt.legend()
# Log-scaled counts, so that sparsely populated bins remain visible.
plt.yscale("log")
# core_genes = idxwhere(strain_uhgg_filt_prevalence > 0.9)
# shell_genes = idxwhere((strain_uhgg_filt_prevalence < 0.9) & (strain_uhgg_filt_prevalence > 0.1))
# cloud_genes = idxwhere(strain_uhgg_filt_prevalence < 0.1)

In [None]:
# Prevalence classes defined on SPGC strains.
# Per genome, sum the rows of `strain_uhgg_filt` within each prevalence class; the
# result has one row per genome and one column per class (missing classes become 0).
strain_genome_spgc_prevalence_class_tally = (
    strain_uhgg_filt.groupby(spgc_uhgg_class)
    .sum()
    .reindex(prevalence_class_order, fill_value=0)
    .T.astype(int)
)

fig = plt.figure(figsize=(10, 5), tight_layout=True)

# Panel widths are proportional to the number of genomes in each group, so bars have
# the same width in both panels.
gs = gridspec.GridSpec(
    1, 2, width_ratios=[len(ref_list) / len(spgc_list), 1], figure=fig
)
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)

# , axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)
for ax, strain_list, title in zip(
    [ax0, ax1], [ref_list, spgc_list], ["references", "spgc"]
):
    # TODO: Fix the colors to always match the previous palette.
    # One stacked bar per genome, sorted by its shell tally.
    strain_genome_spgc_prevalence_class_tally.loc[strain_list].sort_values(
        "shell"
    ).plot.bar(stacked=True, ax=ax, width=1)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_aspect(0.08)
    ax.legend_.set_visible(False)
# Show a single legend, on the right-hand panel only.
ax1.legend_.set_visible(True)
ax1.legend(bbox_to_anchor=(1, 1))

In [None]:
# Prevalence classes defined on ref strains.
# Per genome, sum the rows of `strain_uhgg_filt` within each prevalence class; the
# result has one row per genome and one column per class (missing classes become 0).
strain_genome_ref_prevalence_class_tally = (
    strain_uhgg_filt.groupby(ref_uhgg_class)
    .sum()
    .reindex(prevalence_class_order, fill_value=0)
    .T.astype(int)
)

fig = plt.figure(figsize=(10, 5), tight_layout=True)

# Panel widths are proportional to the number of genomes in each group, so bars have
# the same width in both panels.
gs = gridspec.GridSpec(
    1, 2, width_ratios=[len(ref_list) / len(spgc_list), 1], figure=fig
)
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)

# , axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)
for ax, strain_list, title in zip(
    [ax0, ax1], [ref_list, spgc_list], ["references", "spgc"]
):
    # TODO: Fix the colors to always match the previous palette.
    # One stacked bar per genome, sorted by its shell tally.
    strain_genome_ref_prevalence_class_tally.loc[strain_list].sort_values(
        "shell"
    ).plot.bar(stacked=True, ax=ax, width=1)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_aspect(0.08)
    ax.legend_.set_visible(False)
# Show a single legend, on the right-hand panel only.
ax1.legend_.set_visible(True)
ax1.legend(bbox_to_anchor=(1, 1))

In [None]:
# Quantiles (10/25/50/75/90%) of `centroid_99_length` within each SPGC-defined
# prevalence class, shown as one row per class and one column per quantile.
gene_annotations.groupby(spgc_uhgg_class).centroid_99_length.quantile(
    [0.1, 0.25, 0.5, 0.75, 0.9]
).unstack()

In [None]:
# COG categories, in the column order of the gene-by-category matrix, and a palette
# for them with an extra grey entry for "no_category".
cog_category_order = gene_x_cog_category_matrix.columns
cog_category_palette = lib.plot.construct_ordered_palette(
    cog_category_order, cm="rainbow", extend=dict(no_category="grey")
)
# Draw a legend-only figure: empty scatters exist just to create the legend handles.
for cog_category in cog_category_order:
    plt.scatter([], [], c=cog_category_palette[cog_category], label=cog_category)
plt.legend(ncols=4)

lib.plot.hide_axes_and_spines()

In [None]:
# Human-readable COG category descriptions, indexed by category code. The path is
# relative to the current working directory.
cog_category_description = pd.read_table(
    "ref/cog-20.categories.tsv",
    names=["cog_category", "description"],
    index_col="cog_category",
).assign(description=lambda x: x.index + ": " + x.description)  # Prefix with the code.
# Extra row used for genes without any COG annotation.
cog_category_description.loc["no_category", "description"] = "-: No Annotation"
cog_category_description

In [None]:
# Fisher's exact test for over-/under-representation of every COG category within every
# SPGC-defined prevalence class (one 2x2 table per class/category pair).
x = spgc_uhgg_class
# Genes missing from the COG matrix count as "not in the category".
y = gene_x_cog_category_matrix.reindex(x.index, fill_value=False)
# (disabled) .assign(no_category=lambda x: x.sum(1) == 0)
gene_list = spgc_uhgg_class.index

_enrichment_records = []
for _prevalence_class, _cog_category in product(
    prevalence_class_order, list(cog_category_order)
):
    _membership = pd.DataFrame(
        dict(
            is_prev_class=(x == _prevalence_class),
            is_cog_category=y[_cog_category],
        )
    )
    # 2x2 table: rows = in prevalence class?, columns = in COG category?.
    # Combinations that never occur are filled in as 0.
    contingency_table = (
        _membership.value_counts()
        .unstack()
        .reindex(index=[False, True], columns=[False, True])
        .fillna(0)
    )
    _test = sp.stats.fisher_exact(contingency_table)
    _enrichment_records.append(
        dict(
            prevalence_class=_prevalence_class,
            cog_category=_cog_category,
            statistic=_test[0],
            pvalue=_test[1],
            gene_count=contingency_table.loc[True, True],
        )
    )

cog_category_gene_class_enrichment_test = (
    pd.DataFrame(
        _enrichment_records,
        columns=[
            "prevalence_class",
            "cog_category",
            "statistic",
            "pvalue",
            "gene_count",
        ],
    )
    .set_index(["prevalence_class", "cog_category"])
    .assign(
        negative_log10_pvalue=lambda x: -np.log10(x.pvalue),
        log2_odds_ratio=lambda x: np.log2(x.statistic),
    )
)

In [None]:
def _assign_significance_marker(pvalue):
    if pvalue < 1e-5:
        return "***"
    elif pvalue < 1e-3:
        return "**"
    elif pvalue < 0.05:
        return "*"
    else:
        return ""

In [None]:
# Heatmap of log2 odds ratios (COG category x prevalence class), annotated with
# significance stars.
# Infinite log-odds are first turned into NaN and then filled with 0, so they are drawn
# with the neutral colour of the diverging colour map.
x = (
    cog_category_gene_class_enrichment_test.log2_odds_ratio.unstack("prevalence_class")
    .replace({np.inf: np.nan, -np.inf: np.nan})
    .join(cog_category_description)
    .set_index("description")[prevalence_class_order]
    .fillna(0)
)
# annot = (cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker) + '|' + cog_category_gene_class_enrichment_test.gene_count.astype(int).astype(str)).unstack('prevalence_class')[prevalence_class_order]
# Star annotations, re-indexed by the same category descriptions as `x`.
annot = (
    cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker)
    .unstack("prevalence_class")
    .join(cog_category_description)
    .set_index("description")[prevalence_class_order]
)
# annot = cog_category_gene_class_enrichment_test.gene_count.unstack('prevalence_class')[prevalence_class_order].astype(int)

# Rows sorted by enrichment in the core class, strongest first.
_row_order = x["core"].sort_values(ascending=False).index
# x, annot = lib.pandas_util.align_indexes(x, annot)

fig, ax = plt.subplots(figsize=(5, 12))
ax = sns.heatmap(
    x.reindex(_row_order),
    annot=annot.reindex(_row_order),
    fmt="",  # Annotations are already strings.
    cmap="coolwarm",
    center=0,
    vmin=-5,
    vmax=5,
    cbar_kws=dict(
        use_gridspec=True, location="left", label="log2(odds ratio)", extend="both"
    ),
    ax=ax,
    yticklabels=1,
    xticklabels=1,
    annot_kws=dict(va="center"),
    # norm=mpl.colors.SymLogNorm(linthresh=1e1),
    # center=0,
)

# The colour bar sits on the left, so the row labels go on the right.
ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=-0, va="center")

In [None]:
# Volcano-style plot of the enrichment tests: log2 odds ratio against -log10 p-value,
# one colour per prevalence class.
fig, ax = plt.subplots()

for prevalence_class in prevalence_class_order:
    d = cog_category_gene_class_enrichment_test.xs(
        prevalence_class, level="prevalence_class"
    )
    ax.scatter(
        "log2_odds_ratio",
        "negative_log10_pvalue",
        data=d,
        c=prevalence_class_palette[prevalence_class],
        label=prevalence_class,
        s=5,
    )

# Annotate top-10
# `idx` is a (prevalence_class, cog_category) tuple; the label is the COG category and
# the colour is that of the prevalence class.
for idx in (
    cog_category_gene_class_enrichment_test.sort_values(
        "negative_log10_pvalue", ascending=False
    )
    .head(10)
    .index
):
    ax.annotate(
        idx[1],
        xy=cog_category_gene_class_enrichment_test.loc[idx][
            ["log2_odds_ratio", "negative_log10_pvalue"]
        ],
        xytext=(2, 1),
        textcoords="offset pixels",
        color=prevalence_class_palette[idx[0]],
        # ha='center',
        # va='center',
        fontweight="bold",
    )
ax.axvline(0, linestyle="--", lw=1, color="k")
# Horizontal line: 0.05 divided by the number of tests (Bonferroni threshold).
ax.axhline(
    -np.log10(0.05 / cog_category_gene_class_enrichment_test.shape[0]),
    linestyle="--",
    lw=1,
    color="k",
)
ax.legend()
# cog_category_gene_class_enrichment_test.sort_values('negative_log10_pvalue', ascending=False).head(20)

### Strength of Phylogenetic Signal in Shell Genes

In [None]:
# Shell genes (by SPGC prevalence class) that are also present in the filtered
# gene table.
#
# The list follows the row order of `strain_uhgg_filt` rather than coming out of a set
# intersection. Set iteration order is hash-randomized, which made the row order of `x`
# (and order-sensitive float summaries in later cells) differ between kernel restarts.
_shell_gene_set = set(idxwhere((spgc_uhgg_class == "shell")))
shell_gene_filt_list = [
    gene_id
    for gene_id in strain_uhgg_filt.index.unique()
    if gene_id in _shell_gene_set
]

x = strain_gene_uhgg.loc[shell_gene_filt_list]

# # Weighted
# strain_shell_gene_cdist = sp.spatial.distance.pdist(
#     x.T, metric="cosine", w=ref_gene_uhgg_entropy.loc[x.index]
# )

# Unweighted
# Jaccard dissimilarity between genomes (columns of `x`) over the shell genes only.
strain_shell_gene_cdist = sp.spatial.distance.pdist(
    x.T,
    metric="jaccard",
)
strain_shell_gene_pdist = pd.DataFrame(
    squareform(strain_shell_gene_cdist), index=x.columns, columns=x.columns
)
# Average-linkage clustering of genomes, with optimally ordered leaves.
strain_shell_gene_linkage = sp.cluster.hierarchy.optimal_leaf_ordering(
    fastcluster.linkage(strain_shell_gene_cdist, method="average"),
    strain_shell_gene_cdist,
)

# _col_linkage = strain_shell_gene_linkage
_col_colors = genome_type.map(genome_type_palette)
_col_linkage = strain_shell_gene_linkage

if show_unimportant_figures:
    sns.clustermap(x, col_linkage=_col_linkage, col_colors=_col_colors)

In [None]:
# FIXME: x is ambiguous here.
# Shell-gene Jaccard dissimilarity heatmap with the diagonal (self-distances) masked.
#
# The mask is applied on an explicit, writable float copy. Assigning through
# `DataFrame.values` is not guaranteed to write back into the frame, and the array is
# read-only under pandas Copy-on-Write.
_dmat = strain_shell_gene_pdist.to_numpy(dtype=float, copy=True)
dii, djj = np.diag_indices_from(_dmat)
_dmat[dii, djj] = np.nan
x = pd.DataFrame(
    _dmat,
    index=strain_shell_gene_pdist.index,
    columns=strain_shell_gene_pdist.columns,
)

# Columns follow the genotype linkage; rows follow the shell-gene linkage.
_col_linkage = strain_geno_linkage
_row_linkage = strain_shell_gene_linkage
_colors = genome_type.map(genome_type_palette)

g = sns.clustermap(
    x,
    col_linkage=_col_linkage,
    row_linkage=_row_linkage,
    col_colors=_colors,
    row_colors=_colors,
)
# NaN cells are not drawn, so the axes background colour marks the masked diagonal.
g.ax_heatmap.set_facecolor("aqua")

In [None]:
# Visual check of the spatial-weights matrix used for Moran's I: genotype distances are
# turned into weights by `geno_pdist_to_weights` (defined elsewhere) and shown as a
# heatmap ordered by the genotype linkage on both axes.
_genome_list = strain_geno_pdist.index
# NOTE(review): `_results` is reset here but not used in this cell.
_results = {}
w = geno_pdist_to_weights(strain_geno_pdist.loc[_genome_list, _genome_list], rate=1)

d = pd.DataFrame(w, index=_genome_list, columns=_genome_list)
sns.clustermap(d, col_linkage=strain_geno_linkage, row_linkage=strain_geno_linkage)

In [None]:
# First, check if there's any "phylogenetic signal" of being SPGC vs. other sources of strains.

# Genotype distances (condensed vector plus labels) from the pickled input.
# NOTE(review): `pickle.load` can execute arbitrary code embedded in the file; only
# point `geno_pdist_inpath` at trusted, pipeline-generated files.
with open(geno_pdist_inpath, "rb") as f:
    d = pickle.load(f)

_cdmat, _labels = d["cdmat"], d["labels"]
_pdmat = pd.DataFrame(squareform(_cdmat), index=_labels, columns=_labels)

w = pd.DataFrame(
    geno_pdist_to_weights(_pdmat, rate=1),
    index=_pdmat.index,
    columns=_pdmat.columns,
)

# Boolean indicator per genome: True for SPGC genomes, False for every other type.
x = pd.concat(
    [
        reference_meta.Genome_type,
        spgc_meta.assign(Genome_type="SPGC").Genome_type,
    ]
).isin(["SPGC"])
# Align the weights to the genomes (and their order) in `x`.
w = w.loc[x.index, x.index]

morans_i(x.values, w.values, return_zscore=True, centered=True)

In [None]:
# This is the slowest running cell, because it can have 10s or 100s of thousands of genes for some species.
#
# Computes the centered Moran's I of every gene's values across genomes, separately for
# the reference genomes and the SPGC genomes, using genotype-derived weights.

_gene_list = strain_gene_uhgg.index
# TODO: Pick best weight function
w = pd.DataFrame(
    geno_pdist_to_weights(strain_geno_pdist, rate=1),
    index=strain_geno_pdist.index,
    columns=strain_geno_pdist.columns,
)

_phylogenetic_signal = {}
for _group, _genome_list in [("ref", ref_list), ("spgc", spgc_list)]:
    _w = w.loc[_genome_list, _genome_list].values
    # Pull the gene-by-genome values out of pandas once per group. A label-based
    # `.loc[gene_id, ...]` lookup per gene dominated the runtime of this cell.
    _gene_x_genome = strain_gene_uhgg.loc[:, _genome_list].values
    _results = {}
    # Rows of `_gene_x_genome` are in the same order as `_gene_list`.
    for gene_id, _gene_values in zip(tqdm(_gene_list), _gene_x_genome):
        _results[gene_id] = morans_i(
            _gene_values,
            _w,
            return_zscore=False,
            centered=True,
        )
    _phylogenetic_signal[_group] = pd.Series(_results)

phylogenetic_signal_ref = _phylogenetic_signal["ref"]
phylogenetic_signal_spgc = _phylogenetic_signal["spgc"]

In [None]:
# 2-D histogram of Moran's I in reference genomes (x) against SPGC genomes (y), for
# genes with intermediate SPGC prevalence.
# NOTE: despite its name, `_gene_list` is a boolean mask here, not a list of gene IDs.
_gene_list = (spgc_uhgg_prevalence > 0.1) & (
    spgc_uhgg_prevalence < 0.9
)  # phylogenetic_signal_ref.index  # shell_gene_filt_list

# Missing Moran's I values are replaced by 0.
d = (
    pd.DataFrame(dict(ref=phylogenetic_signal_ref, spgc=phylogenetic_signal_spgc))
    .loc[_gene_list]
    .fillna(0)
)


bins = 200  # np.linspace(-100, 100)

fig, ax = plt.subplots(figsize=(12, 10))
# Skip the statistics and the plot when no gene passes the prevalence filter.
if not d.empty:
    print(len(d))
    print(sp.stats.pearsonr(d.ref, d.spgc))

    plt.hist2d(
        x="ref",
        y="spgc",
        data=d,
        # kind="reg",
        bins=(bins, bins),
        norm=mpl.colors.PowerNorm(1 / 5),
    )
    plt.colorbar()
    None
# plt.xscale('symlog', linthresh=1e-4, linscale=0.05)
# plt.yscale('symlog', linthresh=1e-4, linscale=0.05)
# plt.xlim(-5e-3, 1e-1)
# plt.ylim(-5e-3, 1e-1)

In [None]:
# Overlaid histograms of Moran's I for shell genes: SPGC genomes first, reference
# genomes second.
_gene_list = shell_gene_filt_list

# Symmetric, log-spaced bin edges: -1 ... -1e-4, 0, 1e-4 ... 1.
bins = (
    list(reversed(-np.logspace(-4, -0, num=50)))
    + [0]
    + list(np.logspace(-4, -0, num=50))
)

plt.hist(phylogenetic_signal_spgc.loc[_gene_list], alpha=0.5, bins=bins)
plt.hist(phylogenetic_signal_ref.loc[_gene_list], alpha=0.5, bins=bins)

# Symlog axes: linear within +/-1e-4, logarithmic outside.
plt.xscale("symlog", linthresh=1e-4, linscale=0.05)
plt.yscale("symlog", linthresh=1e-4, linscale=0.05)

In [None]:
# Same ref-vs-SPGC comparison of Moran's I as the earlier 2-D histogram, but with
# log-spaced bins and symlog axes.
# NOTE: despite its name, `_gene_list` is a boolean mask here, not a list of gene IDs.
_gene_list = (spgc_uhgg_prevalence > 0.1) & (
    spgc_uhgg_prevalence < 0.9
)  # phylogenetic_signal_ref.index  # shell_gene_filt_list

# Missing Moran's I values are replaced by 0.
d = (
    pd.DataFrame(dict(ref=phylogenetic_signal_ref, spgc=phylogenetic_signal_spgc))
    .loc[_gene_list]
    .fillna(0)
)


# Symmetric, log-spaced bin edges: -1 ... -1e-4, 0, 1e-4 ... 1.
bins = (
    list(reversed(-np.logspace(-4, -0, num=50)))
    + [0]
    + list(np.logspace(-4, -0, num=50))
)

fig, ax = plt.subplots(figsize=(12, 10))
# Skip the statistics and the plot when no gene passes the prevalence filter.
if not d.empty:
    print(len(d))
    print(sp.stats.pearsonr(d.ref, d.spgc))

    plt.hist2d(
        x="ref",
        y="spgc",
        data=d,
        # kind="reg",
        bins=(bins, bins),
        norm=mpl.colors.PowerNorm(1 / 5),
    )
    plt.colorbar()
    None
# Axis scaling and limits are applied even when nothing was plotted.
plt.xscale("symlog", linthresh=1e-4, linscale=0.05)
plt.yscale("symlog", linthresh=1e-4, linscale=0.05)
plt.xlim(-5e-3, 1e-1)
plt.ylim(-5e-3, 1e-1)

In [None]:
# Overlaid histograms of Moran's I for shell genes (SPGC first, reference second), with
# missing values dropped and a log-scaled count axis.
# Symmetric, log-spaced bin edges: -1 ... -1e-5, 0, 1e-5 ... 1.
bins = (
    list(reversed(-np.logspace(-5, 0, num=50))) + [0] + list(np.logspace(-5, 0, num=50))
)
_gene_list = shell_gene_filt_list  # idxwhere((ref_uhgg_prevalence > 0.1) & (ref_uhgg_prevalence < 0.9))
# plt.plot(bins)
plt.hist(phylogenetic_signal_spgc.loc[_gene_list].dropna(), bins=bins, alpha=0.5)
plt.hist(phylogenetic_signal_ref.loc[_gene_list].dropna(), bins=bins, alpha=0.5)

# The linear threshold matches the smallest non-zero bin edge (1e-5).
plt.xscale("symlog", linthresh=1e-5, linscale=0.05)
plt.yscale("log")

In [None]:
# Centered Moran's I of each gene (reference genomes) against the gene's prevalence in
# the reference genomes.
d = phylogenetic_signal_ref.to_frame("morans_i_centered").assign(
    prevalence=ref_uhgg_prevalence
)
# No colour data (`c=`) is available: the z-score is not computed upstream. A `norm=`
# would therefore be ignored (matplotlib warns) and a colour bar would describe no
# data, so both are left out.
plt.scatter(
    "prevalence",
    "morans_i_centered",
    data=d,
    # c="zscore",
    s=5,
    alpha=1.0,
)
plt.xlabel("reference prevalence")
plt.ylabel("centered Moran's I (reference genomes)")

In [None]:
# Shell-gene heatmap whose rows are annotated with per-gene colours: Moran's I in SPGC
# genomes, Moran's I in reference genomes, and `ref_gene_uhgg_entropy`.
x = strain_gene_uhgg.loc[shell_gene_filt_list]  # , spgc_list]

# Min/max of each Moran's I series, used to rescale values onto [0, 1] before the
# colour-map lookup.
min_spgc_morans_i, max_spgc_morans_i = (
    phylogenetic_signal_spgc.min(),
    phylogenetic_signal_spgc.max(),
)
min_ref_morans_i, max_ref_morans_i = (
    phylogenetic_signal_ref.min(),
    phylogenetic_signal_ref.max(),
)


def _cmap_spgc(x):
    """Map an SPGC Moran's I value to a coolwarm colour (min-max scaled)."""
    return mpl.cm.coolwarm(
        (x - min_spgc_morans_i) / (max_spgc_morans_i - min_spgc_morans_i)
    )


def _cmap_ref(x):
    """Map a reference Moran's I value to a coolwarm colour (min-max scaled)."""
    return mpl.cm.coolwarm(
        (x - min_ref_morans_i) / (max_ref_morans_i - min_ref_morans_i)
    )


_col_colors = genome_type.map(genome_type_palette)
_row_colors = pd.DataFrame(
    dict(
        phylo_spgc=phylogenetic_signal_spgc.map(_cmap_spgc),
        phylo_ref=phylogenetic_signal_ref.map(_cmap_ref),
        entrp=ref_gene_uhgg_entropy.map(mpl.cm.viridis),
    )
)

if show_unimportant_figures:
    sns.clustermap(
        x,
        row_colors=_row_colors,
        col_linkage=strain_geno_linkage,
        col_colors=_col_colors,
    )

In [None]:
# For every COG category, compare the SPGC Moran's I of shell genes inside the category
# with that of shell genes outside it (Mann-Whitney U test).
cog_category_phylogenetic_signal_test = {}

# Category membership aligned to the genes that have a Moran's I value. This does not
# depend on the category, so it is computed once, outside the loop.
# `fill_value=False` plus the explicit bool cast keep the columns boolean. A plain
# `reindex(...).fillna(False)` can leave object dtype, on which the `~` below is not a
# logical NOT.
_is_in_cog_category = (
    gene_x_cog_category_matrix.reindex(phylogenetic_signal_spgc.index, fill_value=False)
    .fillna(False)
    .astype(bool)
)

for _cog_category in tqdm(list(cog_category_order)):
    d = pd.DataFrame(
        dict(
            phylogenetic_signal=phylogenetic_signal_spgc,
            is_cog_category=_is_in_cog_category[_cog_category],
        )
    ).loc[shell_gene_filt_list]
    x = d[d.is_cog_category].phylogenetic_signal
    y = d[~d.is_cog_category].phylogenetic_signal
    # Only test when both groups have more than 5 genes; otherwise record NaN.
    if (len(x) > 5) and (len(y) > 5):
        _test = sp.stats.mannwhitneyu(x, y)
    else:
        _test = (np.nan, np.nan)
    cog_category_phylogenetic_signal_test[_cog_category] = (
        len(x),
        x.median(),
        y.median(),
        x.mean(),
        y.mean(),
        *_test,
    )

# One row per COG category; the dict values above become the columns named below.
cog_category_phylogenetic_signal_test = pd.DataFrame(
    cog_category_phylogenetic_signal_test,
    index=[
        "num_genes_in_category",
        "cog_median_i",
        "not_cog_median_i",
        "cog_mean_i",
        "not_cog_mean_i",
        "mwu_statistic",
        "pvalue",
    ],
).T.assign(
    negative_log10_pvalue=lambda x: -np.log10(x.pvalue),
    median_diff=lambda x: x.cog_median_i - x.not_cog_median_i,
)

In [None]:
# Volcano-style plot of the per-category test: difference in median Moran's I against
# the p-value. Categories without a median difference are dropped.
d = cog_category_phylogenetic_signal_test.join(cog_category_description).dropna(
    subset=["median_diff"]
)

fig, ax = plt.subplots()
plt.scatter(
    "median_diff",
    "pvalue",
    data=d,
)
# lib.plot.rotate_xticklabels()
ax.axvline(0, linestyle="--", lw=1, color="k")
# ax.axhline(-np.log10(0.05 / d.shape[0]), linestyle='--', lw=1, color='k')


# Label categories with a non-zero median difference and p < 1e-4.
for cog_category in idxwhere((np.abs(d.median_diff) > 0) & (d.pvalue < 1e-4)):
    plt.annotate(
        cog_category,
        xy=d[["median_diff", "pvalue"]].loc[cog_category],
        xytext=(4, 2),
        textcoords="offset points",
        # color=prevalence_class_palette[prevalence_class],
        # ha='center',
        # va='center',
        # fontweight='bold',
    )

# Log-scaled, inverted y-axis: smaller p-values appear higher up.
plt.yscale("log")
plt.xlim(-0.04, 0.04)
plt.ylim(1e-15, 1)
plt.gca().invert_yaxis()
plt.xlabel("Difference in Median Moran's I\n(in COG category vs. not)")
plt.ylabel("P-value")

d.sort_values("median_diff", ascending=False)

In [None]:
# One-column heatmap of the difference in median Moran's I per COG category, annotated
# with significance stars and using the row order of the enrichment analysis.
# NOTE(review): the first `.join(cog_category_description)` looks redundant, since only
# `median_diff` is kept before the description is joined again; confirm before removing.
x = (
    cog_category_phylogenetic_signal_test.join(cog_category_description)[
        ["median_diff"]
    ]
    .dropna()
    .join(cog_category_description)
    .set_index("description")
)

# Order rows based on the COG-category enrichment analysis
# (descending log2 odds ratio in the core class; infinite values are treated as 0).
_row_order = (
    cog_category_gene_class_enrichment_test.log2_odds_ratio.unstack("prevalence_class")
    .replace({np.inf: np.nan, -np.inf: np.nan})
    .join(cog_category_description)
    .set_index("description")[prevalence_class_order]
    .fillna(0)["core"]
    .sort_values(ascending=False)
    .index
)
# Star annotations, indexed by category description like `x`.
annot = (
    cog_category_phylogenetic_signal_test.pvalue.map(_assign_significance_marker)
    .to_frame()
    .join(cog_category_description)
    .set_index("description")
)


fig, ax = plt.subplots(figsize=(3.5, 12))
# `ax=` is not passed (it is commented out below), so seaborn draws on the current
# axes, i.e. the one created on the previous line.
ax = sns.heatmap(
    x.reindex(_row_order),
    annot=annot.reindex(_row_order),
    fmt="",
    cmap="PuOr",
    center=0,  # phylogenetic_signal_spgc.mean(),
    cbar_kws=dict(
        use_gridspec=True,
        location="left",
        label="Difference in Median Moran's I",
        extend="both",
        fraction=0.5,
    ),
    # ax=ax,
    # yticklabels=1,
    # xticklabels=1,
    # annot_kws=dict(va='center'),
    # norm=mpl.colors.SymLogNorm(linthresh=1e1),
)

# The colour bar sits on the left, so the row labels go on the right.
ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=-0, va="center")

In [None]:
# Violin plot of SPGC Moran's I per COG category, with categories ordered by their
# median Moran's I among shell genes (from the per-category test table).
# d = phylogenetic_signal_spgc.loc[shell_gene_filt_list].dropna().to_frame().join(gene_x_cog_category).fillna("-")
# All genes with a defined Moran's I are used here; missing values after the join
# (e.g. genes without a category) are filled with "-".
d = (
    phylogenetic_signal_spgc.dropna()
    .to_frame("morans_i_centered")
    .join(gene_x_cog_category)
    .fillna("-")
)


fig, ax = plt.subplots(figsize=(20, 5))
sns.violinplot(
    data=d,
    x="cog_category",
    y="morans_i_centered",
    ax=ax,
    order=cog_category_phylogenetic_signal_test.cog_median_i.sort_values().index,
)

## Gene Co-occurrence

### Found using SPGC only

In [None]:
# Per-gene prevalence among SPGC genomes (row mean over the SPGC columns).
gene_prevalence = strain_uhgg_filt[spgc_list].mean(1)
plt.hist(gene_prevalence)
plt.yscale("log")
# "Variable" genes: prevalence strictly between 5% and 95%.
variable_genes = idxwhere((gene_prevalence > 0.05) & (gene_prevalence < 0.95))
len(variable_genes)

In [None]:
# Cluster genes by their presence/absence pattern across SPGC genomes (co-occurrence).
x = strain_gene_uhgg.astype(bool)
y = x[spgc_list]

# Genes whose pattern is uninformative for co-occurrence: present in no genome, in
# every genome, in exactly one genome, or absent from exactly one genome.
drop_nohit_genes_list = idxwhere(y.sum(1) == 0)
drop_ubiq_genes_list = idxwhere((~y).sum(1) == 0)
drop_single_hit_genes_list = idxwhere(y.sum(1) == 1)
drop_only_one_missing_genes_list = idxwhere((~y).sum(1) == 1)
print(
    len(drop_nohit_genes_list),
    len(drop_ubiq_genes_list),
    len(drop_single_hit_genes_list),
    len(drop_only_one_missing_genes_list),
)

z = y.drop(
    drop_nohit_genes_list
    + drop_ubiq_genes_list
    + drop_single_hit_genes_list
    + drop_only_one_missing_genes_list
)
print(z.shape)

_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
    )
)

# `linkage1` stays None when there is nothing to cluster, so the figure code below can
# check for it instead of hitting an undefined (or stale) name.
linkage1 = None
if len(z) > 1:
    linkage1 = fastcluster.linkage(z, metric="cosine", method="average")
    # # Alternative: Optimal leaf ordering, but potentially MUCH slower.
    # cdmat1 = pdist(z, metric="cosine")
    # linkage1 = sp.cluster.hierarchy.optimal_leaf_ordering(fastcluster.linkage(cdmat1, method="average"), cdmat1)

    # Cut the tree at cosine distance 0.01, i.e. genes with near-identical patterns.
    clust1 = pd.Series(
        sp.cluster.hierarchy.fcluster(
            linkage1,
            t=0.01,
            criterion="distance",
        ),
        index=z.index,
    )
    print("DONE: Clustering")
else:
    # Zero or one remaining gene: a linkage cannot be built, so each gene (if any) is
    # its own cluster.
    clust1 = pd.Series(np.arange(len(z)), index=z.index)

# Add back ubiquitous and nohit genes as clusters.
clust1 = pd.concat(
    [
        clust1,
        pd.Series(-1, index=drop_nohit_genes_list),
        pd.Series(-2, index=drop_ubiq_genes_list),
        pd.Series(-3, index=drop_single_hit_genes_list),
        pd.Series(-4, index=drop_only_one_missing_genes_list),
    ]
)
# With very few genomes a gene can fall into two of the drop lists (e.g. one hit and
# one miss among two genomes). Keep only its first label so the index stays unique for
# downstream `groupby(clust1)` and `map(clust1)` calls.
clust1 = clust1[~clust1.index.duplicated(keep="first")]
clust1_palette = lib.plot.construct_ordered_palette(clust1.unique())

if show_unimportant_figures and linkage1 is not None:
    sns.clustermap(
        x.loc[z.index],
        row_colors=z.index.to_series().map(clust1).map(clust1_palette),
        col_linkage=strain_uhgg_filt_unweighted_jacc_linkage,
        row_linkage=linkage1,
        col_colors=_col_colors,
    )

In [None]:
# Number of genes carrying each cluster label (sentinel labels included).
clust1_sizes = clust1.value_counts()

# COG-category tallies summed within each cluster; index labels and values are
# both cast to int.
_clust1_cog_tally = (
    gene_x_cog_category_matrix.groupby(clust1).sum().rename(int).astype(int)
)
# For every cluster, the five COG categories with the highest tallies.
_clust1_top5_cc = _clust1_cog_tally.apply(
    lambda row: row.sort_values(ascending=False).head(5).index.to_list(), axis=1
)

clust1_sizes_meta = clust1.value_counts().to_frame("tally").assign(
    top5_cc=_clust1_top5_cc
)

clust1_sizes_meta.head(20)

In [None]:
# Inspect the first "real" SPGC co-occurrence cluster, i.e. ignoring the
# sentinel labels -1 .. -4 assigned to the genes set aside before clustering.
_real_clusters = clust1_sizes_meta.drop([-1, -2, -3, -4], errors="ignore")
if _real_clusters.empty:
    print("No real clusters.")
else:
    # clust1_sizes_meta is built from value_counts(), so rows are ordered by
    # decreasing size and the first remaining row is the largest cluster.
    _clust = _real_clusters.index[0]

    _gene_list = idxwhere(clust1 == _clust)
    x = strain_gene_uhgg.loc[_gene_list]
    y = strain_gene_uhgg_depth.loc[_gene_list]
    _col_colors = pd.DataFrame(
        dict(
            genome_type=genome_type.map(genome_type_palette),
        )
    )
    # log10(length) / 4 mapped through viridis.
    # NOTE(review): computed but not passed to the heatmaps below.
    _row_colors = (
        gene_annotations.reindex(_gene_list)
        .centroid_99_length.map(np.log10)
        .map(lambda x: x / 4)
        .map(mpl.cm.viridis)
    )
    _col_linkage = strain_geno_linkage
    _row_linkage = sp.cluster.hierarchy.linkage(
        x, method="average", metric="cosine", optimal_ordering=True
    )  # TODO

    # First heatmap: strain_gene_uhgg values; second: depth, same row/column order.
    sns.clustermap(
        x,
        figsize=(10, 5),
        row_linkage=_row_linkage,
        col_linkage=_col_linkage,
        col_colors=_col_colors,
        xticklabels=0,
        yticklabels=0,
    )
    sns.clustermap(
        y,
        figsize=(10, 5),
        row_linkage=_row_linkage,
        col_linkage=_col_linkage,
        col_colors=_col_colors,
        norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2),
        xticklabels=0,
        yticklabels=0,
    )

    print(phylogenetic_signal_spgc.loc[_gene_list].median())
    print(phylogenetic_signal_ref.loc[_gene_list].median())

    print(
        gene_x_cog_category_matrix.reindex(_gene_list)
        .sum()
        .sort_values(ascending=False)[lambda x: x > 0]
        .to_frame("tally")
        .join(cog_category_description)
    )
    # A bare expression inside a branch is not rendered by the notebook, so
    # show the annotation table explicitly (display is an IPython builtin).
    display(gene_annotations.reindex(_gene_list))

### Found using Ref only

In [None]:
# Co-occurrence clustering of genes using reference genomes only.
#
# x: boolean gene-by-genome table over all genomes (used for plotting).
# y: the same table restricted to the reference genomes (used for clustering).
x = strain_gene_uhgg.astype(bool)
y = x[ref_list]

# Genes that are present in 0 or 1 genomes, or absent from 0 or 1 genomes, are
# set aside before clustering and given sentinel labels (-1 .. -4) below.
drop_nohit_genes_list = idxwhere(y.sum(1) == 0)
drop_ubiq_genes_list = idxwhere((~y).sum(1) == 0)
drop_single_hit_genes_list = idxwhere(y.sum(1) == 1)
drop_only_one_missing_genes_list = idxwhere((~y).sum(1) == 1)
print(
    len(drop_nohit_genes_list),
    len(drop_ubiq_genes_list),
    len(drop_single_hit_genes_list),
    len(drop_only_one_missing_genes_list),
)

z = y.drop(
    drop_nohit_genes_list
    + drop_ubiq_genes_list
    + drop_single_hit_genes_list
    + drop_only_one_missing_genes_list
)
print(z.shape)

_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
    )
)

if not z.empty:
    linkage3 = fastcluster.linkage(z, metric="cosine", method="average")
    # # Alternative: Optimal leaf ordering, but potentially MUCH slower.
    # cdmat3 = pdist(z, metric="cosine")
    # linkage3 = sp.cluster.hierarchy.optimal_leaf_ordering(fastcluster.linkage(cdmat3, method="average"), cdmat3)

    # Cut the tree at cosine distance 0.035 to obtain flat cluster labels.
    clust3 = pd.Series(
        sp.cluster.hierarchy.fcluster(
            linkage3,
            t=0.035,
            criterion="distance",
        ),
        index=z.index,
    )
    print("DONE: Clustering")
else:
    # Nothing left to cluster. Bind linkage3 explicitly so that a stale value
    # from an earlier run cannot leak into later cells.
    linkage3 = None
    clust3 = pd.Series(np.arange(len(z)), index=z.index)

# Add back ubiquitous and nohit genes as clusters.
clust3 = pd.concat(
    [
        clust3,
        pd.Series(-1, index=drop_nohit_genes_list),
        pd.Series(-2, index=drop_ubiq_genes_list),
        pd.Series(-3, index=drop_single_hit_genes_list),
        pd.Series(-4, index=drop_only_one_missing_genes_list),
    ]
)
# The palette must be keyed by this clustering's own labels (clust3, not the
# SPGC clustering), otherwise labels unique to clust3 have no colour.
clust3_palette = lib.plot.construct_ordered_palette(clust3.unique())

# The heatmap needs a row linkage, which only exists when z is non-empty.
if show_unimportant_figures and not z.empty:
    sns.clustermap(
        x.loc[z.index],
        row_colors=z.index.to_series().map(clust3).map(clust3_palette),
        col_linkage=strain_uhgg_filt_unweighted_jacc_linkage,
        row_linkage=linkage3,
        col_colors=_col_colors,
    )

In [None]:
# Number of genes carrying each cluster label (sentinel labels included).
clust3_sizes = clust3.value_counts()

# COG-category tallies summed within each cluster; index labels and values are
# both cast to int.
_clust3_cog_tally = (
    gene_x_cog_category_matrix.groupby(clust3).sum().rename(int).astype(int)
)
# For every cluster, the five COG categories with the highest tallies.
_clust3_top5_cc = _clust3_cog_tally.apply(
    lambda row: row.sort_values(ascending=False).head(5).index.to_list(), axis=1
)

clust3_sizes_meta = clust3.value_counts().to_frame("tally").assign(
    top5_cc=_clust3_top5_cc
)

clust3_sizes_meta.head(20)

In [None]:
# Inspect the first "real" reference-genome co-occurrence cluster, i.e.
# ignoring the sentinel labels -1 .. -4 assigned to the genes set aside before
# clustering. Everything that depends on _clust lives inside the else-branch so
# the cell does nothing (rather than reusing a stale _clust or raising a
# NameError) when no real cluster exists.
_real_clusters = clust3_sizes_meta.drop([-1, -2, -3, -4], errors="ignore")
if _real_clusters.empty:
    print("No real clusters.")
else:
    # clust3_sizes_meta is built from value_counts(), so rows are ordered by
    # decreasing size and the first remaining row is the largest cluster.
    _clust = _real_clusters.index[0]

    _gene_list = idxwhere(clust3 == _clust)
    x = strain_gene_uhgg.loc[_gene_list]
    y = strain_gene_uhgg_depth.loc[_gene_list]
    _col_colors = pd.DataFrame(
        dict(
            genome_type=genome_type.map(genome_type_palette),
        )
    )
    # log10(length) / 4 mapped through viridis.
    # NOTE(review): computed but not passed to the heatmaps below.
    _row_colors = (
        gene_annotations.reindex(_gene_list)
        .centroid_99_length.map(np.log10)
        .map(lambda x: x / 4)
        .map(mpl.cm.viridis)
    )
    _col_linkage = strain_geno_linkage
    _row_linkage = sp.cluster.hierarchy.linkage(
        x, method="average", metric="cosine", optimal_ordering=True
    )  # TODO

    # First heatmap: strain_gene_uhgg values; second: depth, same row/column order.
    sns.clustermap(
        x,
        figsize=(10, 5),
        row_linkage=_row_linkage,
        col_linkage=_col_linkage,
        col_colors=_col_colors,
        xticklabels=0,
        yticklabels=0,
    )
    sns.clustermap(
        y,
        figsize=(10, 5),
        row_linkage=_row_linkage,
        col_linkage=_col_linkage,
        col_colors=_col_colors,
        norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2),
        xticklabels=0,
        yticklabels=0,
    )

    print(phylogenetic_signal_spgc.loc[_gene_list].median())
    print(phylogenetic_signal_ref.loc[_gene_list].median())

    print(
        gene_x_cog_category_matrix.reindex(_gene_list)
        .sum()
        .sort_values(ascending=False)[lambda x: x > 0]
        .to_frame("tally")
        .join(cog_category_description)
    )
    # A bare expression inside a branch is not rendered by the notebook, so
    # show the annotation table explicitly (display is an IPython builtin).
    display(gene_annotations.reindex(_gene_list))

### Comparison Among Clusterings

In [None]:
# Adjusted mutual information between the SPGC-only and Ref-only clusterings,
# over every gene whose prevalence lies in the closed interval [0, 1].
_gene_list = idxwhere(gene_prevalence.between(0, 1))

_labels_spgc = clust1[_gene_list]
_labels_ref = clust3[_gene_list]
print(adjusted_mutual_info_score(_labels_spgc, _labels_ref))
# print(adjusted_mutual_info_score(clust1[_gene_list], clust2[_gene_list]))
# print(adjusted_mutual_info_score(clust2[_gene_list], clust3[_gene_list]))

In [None]:
# Adjusted mutual information between the SPGC-only and Ref-only clusterings,
# restricted to genes with prevalence in the closed interval [0.05, 0.95].
_gene_list = idxwhere(gene_prevalence.between(0.05, 0.95))

_labels_spgc = clust1[_gene_list]
_labels_ref = clust3[_gene_list]
print(adjusted_mutual_info_score(_labels_spgc, _labels_ref))
# print(adjusted_mutual_info_score(clust1[_gene_list], clust2[_gene_list]))
# print(adjusted_mutual_info_score(clust2[_gene_list], clust3[_gene_list]))

In [None]:
# Adjusted Rand index between the SPGC-only and Ref-only clusterings, over
# every gene whose prevalence lies in the closed interval [0, 1].
_gene_list = idxwhere(gene_prevalence.between(0, 1))

_labels_spgc = clust1[_gene_list]
_labels_ref = clust3[_gene_list]
print(adjusted_rand_score(_labels_spgc, _labels_ref))
# print(adjusted_rand_score(clust1[_gene_list], clust2[_gene_list]))
# print(adjusted_rand_score(clust2[_gene_list], clust3[_gene_list]))

In [None]:
# Adjusted Rand index between the SPGC-only and Ref-only clusterings,
# restricted to genes with prevalence in the closed interval [0.05, 0.95].
_gene_list = idxwhere(gene_prevalence.between(0.05, 0.95))

_labels_spgc = clust1[_gene_list]
_labels_ref = clust3[_gene_list]
print(adjusted_rand_score(_labels_spgc, _labels_ref))
# print(adjusted_rand_score(clust1[_gene_list], clust2[_gene_list]))
# print(adjusted_rand_score(clust2[_gene_list], clust3[_gene_list]))

## SPGC-MWAS

In [None]:
# For each SPGC strain: does any subject have a mean community value above 0.2
# (averaged over that subject's samples)?
(
    sfacts_fit.community.sel(strain=spgc_list)
    .to_pandas()
    .groupby(mgen_meta.subject_id)
    .mean()
    .gt(0.2)
    .any()
)

In [None]:
# _row_linkage = gene_uhgg_linkage
_gene_data = strain_gene_uhgg

# Subject-by-strain presence: mean community value per subject above 0.2,
# keeping only subjects that carry at least one SPGC strain.
_subject_mean_community = (
    sfacts_fit.community.sel(strain=spgc_list)
    .to_pandas()
    .groupby(mgen_meta.subject_id)
    .mean()
)
u = (_subject_mean_community > 0.2)[lambda x: x.any(axis=1)]

# Gene-by-strain table restricted to the same strains.
v = _gene_data[u.columns]

# A gene is called present in a subject when any of that subject's strains
# contributes a positive value for it.
subject_strain_gene_content = (u @ v.T).T.gt(0)
subject_strain_gene_prevalence = subject_strain_gene_content.mean(axis=1)
num_strain_subjects = len(subject_strain_gene_content.columns)

subject_ibd_diagnosis = subject.loc[subject_strain_gene_content.columns].ibd_diagnosis

_diagnosis_palette = {"CD": "tab:green", "UC": "tab:blue", "nonIBD": "lightgrey"}
_col_colors = subject_ibd_diagnosis.replace(_diagnosis_palette)
# row_order = sp.cluster.hierarchy.to_tree(_row_linkage).pre_order(lambda x: x.id)
# sns.clustermap(subject_strain_gene_content.iloc[row_order], col_colors=_col_colors, row_cluster=False)
if not u.empty:
    sns.clustermap(
        subject_strain_gene_content,
        col_colors=_col_colors,
        row_cluster=False,
    )

In [None]:
# Gene-level association scan: for every gene, test whether strain-inferred
# gene presence in a subject is associated with IBD diagnosis, using Fisher's
# exact test on a 2x2 (presence x diagnosis group) table.
mwas_results = []
for _gene in tqdm(strain_gene_uhgg.index):
    # Presence (rows) by diagnosis (columns) counts; missing cells become 0 so
    # that the derived IBD column (CD + UC) is always defined.
    contingency_table0 = (
        pd.DataFrame(
            dict(
                diagnosis=subject_ibd_diagnosis,
                gene=subject_strain_gene_content.loc[_gene],
            )
        )
        .value_counts()
        .unstack("diagnosis")
        .reindex(columns=["CD", "UC", "nonIBD"])
        .fillna(0)
        .assign(IBD=lambda x: x.CD + x.UC)
    )
    # Each comparison contrasts nonIBD (groupA) with one case group (groupB).
    for ibd_comparison, (groupA, groupB) in dict(
        ibd=("nonIBD", "IBD"), cd=("nonIBD", "CD"), uc=("nonIBD", "UC"), 
    ).items():
        # Force a full 2x2 layout even when a presence state was never observed.
        contingency_table1 = contingency_table0.reindex(
            index=[True, False], columns=[groupA, groupB]
        ).fillna(0)
        _test = sp.stats.fisher_exact(contingency_table1)
        # Odds ratio (groupB relative to groupA) with a pseudocount of 1 added
        # to every cell, so zero cells cannot cause a division by zero.
        contingency_table_pc = contingency_table1 + 1
        oddsratio_pc = (
            contingency_table_pc.loc[True, groupB]
            / contingency_table_pc.loc[False, groupB]
        ) / (
            contingency_table_pc.loc[True, groupA]
            / contingency_table_pc.loc[False, groupA]
        )
        mwas_results.append((ibd_comparison, _gene, *_test, oddsratio_pc))

# One row per (gene, comparison), indexed by both and ordered by p-value.
# NOTE: after set_index, "gene" and "ibd_comparison" are index levels, not
# columns.
mwas_results = (
    pd.DataFrame(
        mwas_results,
        columns=["ibd_comparison", "gene", "statistic", "pvalue", "oddsratio_pc"],
    )
    # .join(
    #     gene_annotations[
    #         [
    #             "Description",
    #             "COG_category",
    #             "eggNOG_OGs",
    #             "centroid_99_length",
    #             "score",
    #             "Preferred_name",
    #             "KEGG_ko",
    #             "PFAMs",
    #         ]
    #     ],
    #     on="gene",
    # )
    .set_index(['gene', 'ibd_comparison'])
    .sort_values("pvalue")
)
mwas_results.head(10)

In [None]:
# Fisher's exact test: are genes annotated with COG category "G" enriched among
# genes with an IBD p-value below 1e-3, within a single species?
# NOTE(review): mwas_stats and mwas_filt_func are not defined in the part of
# the notebook reviewed here, and gene_meta is (re)built further down - confirm
# this cell survives Restart & Run All.
species_id = "101338"

_species_mwas = mwas_stats[mwas_filt_func][lambda x: (x.species_id == species_id)]
d = pd.DataFrame(
    dict(
        in_G=_species_mwas.join(gene_meta)
        .cog_categories.fillna("")
        .str.contains("G"),
        is_signif=_species_mwas.fisher_exact_pvalue_ibd < 1e-3,
    )
)

contingency = d.value_counts().unstack()
print(sp.stats.fisher_exact(contingency))
contingency

In [None]:
# Show the presence-by-diagnosis table left over from the last gene processed
# by the MWAS loop (a quick sanity check of the table layout).
contingency_table0

In [None]:
# P-value histograms (log-log), one per IBD comparison, with a reference line
# proportional to the bin width.
bins = np.logspace(-5, 0, num=20)
# bins = np.linspace(0, 1, num=51)

# mwas_results is indexed by (gene, ibd_comparison); move the index levels back
# into columns so the comparison label can be used for filtering.
_mwas_long = mwas_results.reset_index()
for ibd_comparison in _mwas_long.ibd_comparison.unique():
    plt.hist(
        _mwas_long.loc[_mwas_long.ibd_comparison == ibd_comparison, "pvalue"],
        label=ibd_comparison,
        alpha=0.5,
        bins=bins,
    )

# Expected count per bin if p-values were uniform on [0, 1].
# NOTE(review): this uses the total row count across all comparisons, while
# each histogram holds one comparison only - confirm the intended scaling.
plt.plot(bins[1:], (bins[1:] - bins[:-1]) * mwas_results.shape[0])
plt.legend()

plt.xscale("log")
plt.yscale("log")

In [None]:
# Presence-by-diagnosis counts for the gene with the smallest MWAS p-value.
# mwas_results is sorted by p-value and indexed by (gene, ibd_comparison), so
# the gene ID is read from the index rather than from a column.
_gene = mwas_results.reset_index().gene.iloc[0]
print(_gene)
pd.DataFrame(
    dict(diagnosis=subject_ibd_diagnosis, gene=subject_strain_gene_content.loc[_gene])
).value_counts().unstack().fillna(0)

In [None]:
# Heatmaps and annotations for the genes behind the ten strongest MWAS hits.
# mwas_results is indexed by (gene, ibd_comparison), so gene IDs come from the
# index. A gene can appear under several comparisons, hence the de-duplication
# (order of first appearance, i.e. increasing p-value, is kept).
_gene_list = mwas_results.head(10).reset_index().gene.drop_duplicates().to_list()
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]
_col_colors = genome_type.map(genome_type_palette)
# log10(length) / 4 mapped through viridis.
# NOTE(review): computed but not passed to the heatmaps below.
_row_colors = (
    gene_annotations.reindex(_gene_list)
    .centroid_99_length.map(np.log10)
    .map(lambda x: x / 4)
    .map(mpl.cm.viridis)
)
_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# First heatmap: strain_gene_uhgg values; second: depth, same row/column order.
sns.clustermap(
    x,
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(
    y,
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2),
    xticklabels=0,
    yticklabels=0,
)

print(
    gene_x_cog_category_matrix.reindex(_gene_list)
    .sum()
    .sort_values(ascending=False)[lambda x: x > 0]
    .to_frame("tally")
    .join(cog_category_description)
)
gene_annotations.reindex(_gene_list)

In [None]:
# Subject-level presence of the genes in _gene_list, inferred through the
# strains each subject carries.
_gene_data = strain_gene_uhgg.loc[_gene_list]

# Subject-by-strain presence: mean community value per subject above 0.2,
# keeping only subjects that carry at least one SPGC strain.
_subject_x_strain = (
    sfacts_fit.community.sel(strain=spgc_list)
    .to_pandas()
    .groupby(mgen_meta.subject_id)
    .mean()
    > 0.2
)[lambda x: x.any(axis=1)]
# Take the strain columns from this cell's own presence table rather than from
# the global `u` of an earlier cell, so the cell does not depend on hidden state.
_strain_x_gene = _gene_data[_subject_x_strain.columns]
_subject_x_gene = (_subject_x_strain @ _strain_x_gene.T).T > 0

_col_colors = subject.loc[_subject_x_gene.columns].ibd_diagnosis.replace(
    {"CD": "tab:green", "UC": "tab:blue", "nonIBD": "lightgrey"}
)
if not _subject_x_gene.empty:
    sns.clustermap(_subject_x_gene, col_colors=_col_colors, row_cluster=False)

In [None]:
# Heatmap of per-subject depth ratios for the genes in _gene_list.
# NaN is mapped to 0 and +/-inf to +/-100 before plotting.
_gene_data = subject_uhgg_depth_ratio.T.loc[_gene_list].applymap(
    np.nan_to_num, nan=0, posinf=100, neginf=-100
)

_diagnosis_palette = {"CD": "tab:green", "UC": "tab:blue", "nonIBD": "lightgrey"}
# Species depth is coloured relative to the deepest subject.
_max_species_depth = subject_mean_species_depth.max()
_col_colors = pd.DataFrame(
    dict(
        ibd_diagnosis=subject.ibd_diagnosis.replace(_diagnosis_palette),
        species_depth=subject_mean_species_depth.map(
            lambda x: mpl.cm.viridis(x / _max_species_depth)
        ),
    )
)

# Columns (subjects) are clustered by cosine distance; gene rows keep their order.
sns.clustermap(
    _gene_data + 1e-5,
    col_colors=_col_colors,
    norm=mpl.colors.SymLogNorm(1e-3, vmin=1e-2, vmax=1e2),
    metric="cosine",
    row_cluster=False,
)

### MWAS On Genes Alone?

In [None]:
# Distribution of the depth ratio of a single gene (_gene, as left by an
# earlier cell) across subjects, split by diagnosis, plus the resulting
# presence/absence table at a fixed threshold.
mean_subject_depth = subject_uhgg_depth_ratio[_gene]

d = subject[["ibd_diagnosis"]].join(mean_subject_depth).dropna()
_depth_bins = np.logspace(-4, 3)
for ibd_diagnosis in d.ibd_diagnosis.unique():
    _in_group = d.ibd_diagnosis == ibd_diagnosis
    # Small offset so zero ratios remain visible on the log axis.
    plt.hist(
        d.loc[_in_group, _gene] + 1e-4,
        bins=_depth_bins,
        alpha=0.5,
        label=ibd_diagnosis,
    )

thresh = 1e-1
plt.axvline(thresh, lw=1, linestyle="--", color="k")
plt.xscale("log")
# plt.yscale('log')
plt.legend()

_called = d.assign(gene=lambda x: x[_gene] > thresh)
_called[["ibd_diagnosis", "gene"]].value_counts().unstack()

In [None]:
# When we just grab gene presence absence in each subject using the depth ratio alone, can we still detect effects?
# Non-finite ratios are treated as 0; a gene is called present above 0.1.
_depth_ratio_by_gene = subject_uhgg_depth_ratio.T.loc[subject_strain_gene_content.index]
_depth_ratio_clean = _depth_ratio_by_gene.replace({np.nan: 0, np.inf: 0, -np.inf: 0})
subject_strain_gene_content2 = _depth_ratio_clean > 0.1

# Agreement between the depth-based calls and the strain-inferred calls,
# compared cell by cell over the same subjects.
_calls_from_depth = (
    subject_strain_gene_content2[subject_strain_gene_content.columns]
    .astype(float)
    .values.flatten()
)
_calls_from_strains = subject_strain_gene_content.astype(float).values.flatten()
sp.stats.pearsonr(_calls_from_depth, _calls_from_strains)

In [None]:
# Repeat of the gene-level association scan, this time with gene presence
# called directly from the per-subject depth ratio
# (subject_strain_gene_content2) instead of being inferred through strains.
mwas_results2 = []
for _gene in tqdm(strain_gene_uhgg.index):
    # Presence (rows) by diagnosis (columns) counts. Missing cells must become
    # 0 *before* IBD = CD + UC is derived; otherwise one missing cell turns
    # the whole IBD count into NaN (and later into 0).
    contingency_table0 = (
        pd.DataFrame(
            dict(
                diagnosis=subject_ibd_diagnosis,
                gene=subject_strain_gene_content2.loc[_gene],
            )
        )
        .value_counts()
        .unstack("diagnosis")
        .reindex(columns=["CD", "UC", "nonIBD"])
        .fillna(0)
        .assign(IBD=lambda x: x.CD + x.UC)
    )
    # Each comparison contrasts nonIBD (groupA) with one case group (groupB).
    for ibd_comparison, (groupA, groupB) in dict(
        cd=("nonIBD", "CD"), uc=("nonIBD", "UC"), ibd=("nonIBD", "IBD")
    ).items():
        # Force a full 2x2 layout even when a presence state was never observed.
        contingency_table1 = contingency_table0.reindex(
            index=[True, False], columns=[groupA, groupB]
        ).fillna(0)
        _test = sp.stats.fisher_exact(contingency_table1)
        # Odds ratio (groupB relative to groupA) with a pseudocount of 1 added
        # to every cell, so zero cells cannot cause a division by zero.
        contingency_table_pc = contingency_table1 + 1
        oddsratio_pc = (
            contingency_table_pc.loc[True, groupB]
            / contingency_table_pc.loc[False, groupB]
        ) / (
            contingency_table_pc.loc[True, groupA]
            / contingency_table_pc.loc[False, groupA]
        )
        mwas_results2.append((ibd_comparison, _gene, *_test, oddsratio_pc))

# One row per (gene, comparison), ordered by p-value and indexed by both.
mwas_results2 = (
    pd.DataFrame(
        mwas_results2,
        columns=["ibd_comparison", "gene", "statistic", "pvalue", "oddsratio_pc"],
    )
    # .join(
    #     gene_annotations[
    #         [
    #             "Description",
    #             "COG_category",
    #             "eggNOG_OGs",
    #             "centroid_99_length",
    #             "score",
    #             "Preferred_name",
    #             "KEGG_ko",
    #             "PFAMs",
    #         ]
    #     ],
    #     on="gene",
    # )
    .sort_values("pvalue")
    .set_index(['gene', 'ibd_comparison'])
)

mwas_results2.head(10)

In [None]:
# Compare log odds ratios from the strain-based and the depth-based scans.
# Both result tables are sorted by their own p-values, and pearsonr / hist2d
# pair values by position, so y is first re-ordered to x's (gene, comparison)
# index to make each pair refer to the same test.
x = mwas_results.oddsratio_pc.apply(np.log)
y = mwas_results2.oddsratio_pc.apply(np.log).reindex(x.index)
print(sp.stats.pearsonr(x, y))
plt.hist2d(x, y, bins=15, norm=mpl.colors.PowerNorm(1 / 5))
None

In [None]:
# Compare log p-values from the depth-based (x) and the strain-based (y) scans.
# Both result tables are sorted by their own p-values, and pearsonr / hist2d
# pair values by position, so y is first re-ordered to x's (gene, comparison)
# index to make each pair refer to the same test.
x = mwas_results2.pvalue.apply(np.log)
y = mwas_results.pvalue.apply(np.log).reindex(x.index)
print(sp.stats.pearsonr(x, y))
plt.hist2d(x, y, bins=15, norm=mpl.colors.PowerNorm(1 / 5))
plt.colorbar()
None

## Write Statistics

In [None]:
# Output column name -> source column in gene_annotations.
_gene_meta_source_columns = {
    "nlength": "centroid_99_length",
    "eggnog": "eggNOG_OGs",
    "cog_categories": "COG_category",
    "description": "Description",
    "gene_name": "Preferred_name",
    "ko": "KEGG_ko",
    "pfam": "PFAMs",
}

# Per-gene annotation table plus one "cog_category_<X>" column per COG
# category, written as a tab-separated file indexed by gene_id.
_gene_meta_core = pd.DataFrame(
    {
        out_name: gene_annotations[source_name]
        for out_name, source_name in _gene_meta_source_columns.items()
    }
)
_cog_indicator_columns = gene_x_cog_category_matrix.rename(
    columns=lambda x: "cog_category_" + x
)
gene_meta = _gene_meta_core.join(_cog_indicator_columns).rename_axis(index="gene_id")

gene_meta.to_csv(gene_meta_outpath, sep="\t")

In [None]:
# Wide MWAS tables: one column per IBD comparison, suffixed by the statistic.
_mwas_pvalue_wide = mwas_results.pvalue.unstack().rename(
    columns=lambda x: x + "_mwas_pvalue"
)
_mwas_oddsratio_wide = mwas_results.oddsratio_pc.unstack().rename(
    columns=lambda x: x + "_mwas_oddsratio_pc"
)

# Per-gene summary statistics gathered from the analyses above.
_gene_stats_core = pd.DataFrame(
    dict(
        prevalence_spgc=spgc_uhgg_prevalence,
        prevalence_ref=ref_uhgg_prevalence,
        prevalence_mwas_subject=subject_strain_gene_prevalence,
        num_mwas_subject=num_strain_subjects,
        phylogenetic_i_spgc=phylogenetic_signal_spgc,
        phylogenetic_i_ref=phylogenetic_signal_ref,
        coclust_label_spgc=clust1,
        # coclust_label_both=clust2,
        coclust_label_ref=clust3,
    )
)

gene_stats = (
    _gene_stats_core.join(_mwas_pvalue_wide)
    .join(_mwas_oddsratio_wide)
    .rename_axis(index="gene_id")
)

gene_stats.to_csv(gene_stats_outpath, sep="\t")
gene_stats

In [None]:
# Preview the first rows of the per-gene statistics table.
gene_stats.head()

In [None]:
# Rows of the genome-distance comparison that belong to SPGC genomes.
_spgc_distance_rows = genome_distance_comparison_filt_unweighted_jacc[
    lambda x: x.genome_type.isin(["SPGC"])
]

# Per-strain statistics for SPGC genomes: nearest reference and its
# dissimilarities, strain metadata, dereplication cluster and gene tallies per
# prevalence class.
_spgc_nearest_ref = pd.DataFrame(
    dict(
        nearest_ref_geno=_spgc_distance_rows["min_geno_ref"],
        nearest_ref_geno_diss=_spgc_distance_rows["geno_dist"],
        nearest_ref_gene_diss=_spgc_distance_rows["gene_dist"],
    )
)
_spgc_class_tally = strain_genome_spgc_prevalence_class_tally.rename(
    columns=lambda x: f"spgc_{x}_gene_tally"
)
_ref_class_tally = strain_genome_ref_prevalence_class_tally.rename(
    columns=lambda x: f"ref_{x}_gene_tally"
)

spgc_strain_stats = (
    _spgc_nearest_ref.join(spgc_meta.loc[spgc_list])
    .assign(derep_clust=strain_derep_clust.loc[spgc_list])
    .join(_spgc_class_tally)
    .join(_ref_class_tally)
    .rename_axis(index="strain")
)

spgc_strain_stats.to_csv(spgc_strain_stats_outpath, sep="\t")
spgc_strain_stats

In [None]:
# Rows of the genome-distance comparison that belong to reference genomes
# (isolates and MAGs).
_ref_distance_rows = genome_distance_comparison_filt_unweighted_jacc[
    lambda x: x.genome_type.isin(["Isolate", "MAG"])
]

# Per-strain statistics for reference genomes: nearest reference and its
# dissimilarities, assembly quality metadata, dereplication cluster and gene
# tallies per prevalence class.
_ref_nearest_ref = pd.DataFrame(
    dict(
        nearest_ref_geno=_ref_distance_rows["min_geno_ref"],
        nearest_ref_geno_diss=_ref_distance_rows["geno_dist"],
        nearest_ref_gene_diss=_ref_distance_rows["gene_dist"],
        genome_type=_ref_distance_rows["genome_type"],
    )
)
_reference_quality_columns = [
    "Length",
    "N_contigs",
    "N50",
    "GC_content",
    "Completeness",
    "Contamination",
]
_spgc_class_tally = strain_genome_spgc_prevalence_class_tally.rename(
    columns=lambda x: f"spgc_{x}_gene_tally"
)
_ref_class_tally = strain_genome_ref_prevalence_class_tally.rename(
    columns=lambda x: f"ref_{x}_gene_tally"
)

ref_strain_stats = (
    _ref_nearest_ref.join(reference_meta[_reference_quality_columns])
    .assign(derep_clust=strain_derep_clust.loc[ref_list])
    .join(_spgc_class_tally)
    .join(_ref_class_tally)
    .rename_axis(index="strain")
)

ref_strain_stats.to_csv(ref_strain_stats_outpath, sep="\t")
ref_strain_stats

## SPGC-MWAS Experiment (antibiotics)

In [None]:
# Visit metadata table, indexed by visit_id (path is relative to the working
# directory set in the preamble).
visit = pd.read_table("meta/hmp2/visit.tsv", index_col="visit_id")

In [None]:
from sklearn.linear_model import Lasso, LassoCV, LogisticRegressionCV

# _row_linkage = gene_uhgg_linkage
_strain_list = spgc_list

# Gene-by-strain table for the SPGC strains.
_strain_to_gene = strain_gene_uhgg[_strain_list]

# Strain presence per row of the community table (community value above 0.2),
# dropping rows in which none of the strains is present.
_strain_present = sfacts_fit.community.to_pandas()[_strain_list] > 0.2
_has_any_strain = _strain_present.any(axis=1)
_strain_present = _strain_present[_has_any_strain]

# Attach visit metadata and keep only rows shared by both tables.
_meta = mgen_meta.join(visit, on="visit_id", rsuffix="_")
_meta, _strain_present = lib.pandas_util.align_indexes(_meta, _strain_present)

# Multiplied together gives the subject by gene mapping
_strain_gene_product = _strain_present @ _strain_to_gene.T
sample_gene_content = _strain_gene_product > 0

# # Although I guess some subjects may have just a small subset of strains with that gene.
# sample_strain_gene_prevalence = subject_strain_gene_content.mean(1)
# lm.score(sample_gene_content.astype(float), _meta.status_antibiotics.astype(float))

In [None]:
# L1-penalised logistic regression of antibiotic status on gene content.
# The score is computed on the same rows that were used for fitting.
_features = sample_gene_content.astype(float)
_target = _meta.status_antibiotics.astype(float)

_estimator = LogisticRegressionCV(Cs=[1e-0], cv=5, penalty="l1", solver="liblinear")
lm = _estimator.fit(_features, _target)
lm.score(_features, _target)

In [None]:
# Distribution of the fitted per-gene coefficients (log-scaled counts).
_coef_values = lm.coef_.flatten()
plt.hist(_coef_values)
plt.yscale("log")

In [None]:
# The 20 genes with the largest coefficients, in ascending order.
_coef_by_gene = pd.Series(lm.coef_.flatten(), index=sample_gene_content.columns)
_coef_by_gene.sort_values().tail(20)

In [None]:
# Annotations of the 20 genes with the smallest (most negative) coefficients.
_coef_by_gene = pd.Series(lm.coef_.flatten(), index=sample_gene_content.columns)
_lowest_coef_genes = _coef_by_gene.sort_values().head(20).index
gene_meta.loc[_lowest_coef_genes]