## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Run every relative path in this notebook from the project root.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
_os.path.realpath(_os.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
# Seaborn "talk" context for larger fonts; render figures at 50 dpi.
sns.set_context(context="talk")
plt.rcParams.update({"figure.dpi": 50})

## Data Setup

### Metadata

In [None]:
# Species-level taxonomy table (used below via joins on species id).
species_taxonomy_inpath = "ref/gtpro/species_taxonomy_ext.tsv"
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    species_taxonomy_inpath
)
species_taxonomy

In [None]:
# COG category letter -> "<letter>: <description>", plus an entry for unannotated genes.
cog_category_description = pd.read_table(
    "ref/cog-20.categories.tsv",
    names=["cog_category", "description"],
    index_col="cog_category",
)
cog_category_description["description"] = (
    cog_category_description.index + ": " + cog_category_description["description"]
)
cog_category_description.loc["no_category", "description"] = "-: No Annotation"
cog_category_description

In [None]:
# COG category letters A..Z, followed by a bucket for genes without a COG category.
cog_category_order = [chr(code) for code in range(ord("A"), ord("Z") + 1)]
cog_category_order.append("no_category")

# Names of the per-category indicator columns in the gene_meta tables.
cog_category_raw_columns = ["cog_category_" + letter for letter in cog_category_order]

prevalence_class_order = ["core", "shell", "cloud"]

In [None]:
# Unique species ids (as strings) belonging to the xjin_ucfmt_hmp2 species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda df: df["species_group_id"] == "xjin_ucfmt_hmp2", "species_id"]
    .astype(str)
    .unique()
)

### Strain Statistics

In [None]:
# Load per-species strain statistics for SPGC-inferred strains and for reference
# genomes. A species is skipped (and recorded in missing_strain_stats) unless BOTH
# tables exist, so a missing ref table no longer aborts the loop with FileNotFoundError.
spgc_strain_stats = []
ref_strain_stats = []
missing_strain_stats = []
for species in species_list:
    spgc_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.spgc_strain_stats.tsv"
    ref_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.ref_strain_stats.tsv"
    if not (os.path.exists(spgc_path) and os.path.exists(ref_path)):
        missing_strain_stats.append(species)
        continue
    spgc_strain_stats.append(pd.read_table(spgc_path).assign(species=species))
    ref_strain_stats.append(pd.read_table(ref_path).assign(species=species))

# pd.concat([]) fails with an opaque message; fail with context instead.
assert spgc_strain_stats, "No strain-stats tables found for any species in species_list."
spgc_strain_stats = pd.concat(spgc_strain_stats).join(species_taxonomy, on="species")
ref_strain_stats = pd.concat(ref_strain_stats).join(species_taxonomy, on="species")


print(
    len(missing_strain_stats),
    "out of",
    len(species_list),
    "species are missing spgc or ref strain stats.",
)
print(
    len(spgc_strain_stats),
    "spgc strains found across",
    len(spgc_strain_stats.species.unique()),
    "species",
)
print(
    len(ref_strain_stats),
    "ref strains found across",
    len(ref_strain_stats.species.unique()),
    "species",
)

In [None]:
# One colour per phylum observed among reference strains (sorted phylum names).
_phyla = sorted(ref_strain_stats.sort_values("taxonomy_string")["p__"].unique())
phylum_palette = lib.plot.construct_ordered_palette(_phyla, cm="tab10")

# Every phylum must get a distinct colour (tab10 has only 10 colours).
_n_colors = len(phylum_palette.values())
assert len(set(phylum_palette.values())) == _n_colors

### Gene Statistics

In [None]:
# Per-gene strain-gene MWAS results (IBD vs. non-IBD), concatenated across species.
mwas_stats = []
missing_mwas_stats = []
for species in tqdm(species_list):
    mwas_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.strain_gene.hmp2_mwas-f30-n2.tsv"
    if os.path.exists(mwas_path):
        # TODO: Drop the species mapping. We'll get that from the metadata table.
        mwas_stats.append(pd.read_table(mwas_path).assign(species_id=species))
    else:
        missing_mwas_stats.append(species)

_present_cols = ["present-CD", "present-UC", "present-nonIBD"]
_absent_cols = ["absent-CD", "absent-UC", "absent-nonIBD"]
mwas_stats = (
    pd.concat(mwas_stats)
    .set_index("gene_id")
    .assign(
        # Subject tallies summed over the three diagnosis groups.
        total_present=lambda df: df[_present_cols].sum(axis=1),
        total_absent=lambda df: df[_absent_cols].sum(axis=1),
        total_subjects=lambda df: df[["total_present", "total_absent"]].sum(axis=1),
        subject_prevalence=lambda df: df["total_present"]
        / (df["total_present"] + df["total_absent"]),
        log10_fisher_exact_pvalue_ibd=lambda df: np.log10(df["fisher_exact_pvalue_ibd"]),
    )
    .sort_values("fisher_exact_pvalue_ibd")
)
print(len(missing_mwas_stats), "species are missing mwas stats")
mwas_stats

In [None]:
# Per-gene statistics and metadata, concatenated across species and indexed by gene_id.
# A species is skipped (and recorded) unless BOTH the stats and meta tables exist,
# so a missing meta table cannot abort the loop with FileNotFoundError.
gene_stats = []
_gene_meta = []
missing_gene_stats = []
for species in tqdm(species_list):
    stats_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.gene_stats.tsv"
    meta_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.gene_meta.tsv"
    if not (os.path.exists(stats_path) and os.path.exists(meta_path)):
        missing_gene_stats.append(species)
        continue
    gene_stats.append(pd.read_table(stats_path).assign(species=species))
    _gene_meta.append(pd.read_table(meta_path).assign(species=species))
gene_stats = pd.concat(gene_stats).set_index("gene_id")
_gene_meta = pd.concat(_gene_meta).set_index("gene_id")
print(len(missing_gene_stats), "species are missing gene stats")

In [None]:
# Split gene metadata into a gene x COG-letter indicator matrix (missing -> False)
# and the remaining per-gene metadata columns.
_cog_prefix_len = len("cog_category_")
gene_x_cog_category_matrix = (
    _gene_meta[cog_category_raw_columns]
    .rename(columns=lambda col: col[_cog_prefix_len:])
    .fillna(False)
)

gene_meta = _gene_meta.drop(columns=cog_category_raw_columns, errors="ignore")

## Analysis

### Novelty / Believability / Variability

In [None]:
# SPGC strains: SNP vs gene-content dissimilarity to the nearest reference.
# First x-bin is [0, 1e-4] so exact zeros get their own column on the symlog axis.
_geno_bins = [0] + list(np.logspace(-4, 0, num=40))
_gene_bins = np.linspace(0, 1, num=40)
plt.hist2d(
    spgc_strain_stats["nearest_ref_geno_diss"],
    spgc_strain_stats["nearest_ref_gene_diss"],
    bins=(_geno_bins, _gene_bins),
    norm=mpl.colors.PowerNorm(1 / 3),
)
plt.xscale("symlog", linthresh=1e-4, linscale=0.1)

In [None]:
# Reference strains: SNP vs gene-content dissimilarity to the nearest other reference.
_geno_bins = [0] + list(np.logspace(-4, 0, num=40))
_gene_bins = np.linspace(0, 1, num=40)
plt.hist2d(
    ref_strain_stats["nearest_ref_geno_diss"],
    ref_strain_stats["nearest_ref_gene_diss"],
    bins=(_geno_bins, _gene_bins),
    norm=mpl.colors.PowerNorm(1 / 3),
)
plt.xscale("symlog", linthresh=1e-4, linscale=0.1)

In [None]:
# Number of strains in each group after de-replication: for every
# (species, derep cluster) mark which genome types occur, then count the
# clusters containing each genome type.
_strain_clusters = pd.concat(
    [
        spgc_strain_stats[["species", "derep_clust"]].assign(genome_type="SPGC"),
        ref_strain_stats[["species", "derep_clust", "genome_type"]],
    ]
)
(
    _strain_clusters.groupby("species")
    .value_counts()
    .unstack("genome_type", fill_value=0)
    .gt(0)
    .sum()
)

In [None]:
## Entirely new strains "discovered" in my pipeline.
# Per (species, derep cluster): a boolean per genome type for "present in cluster".
_cluster_presence = (
    pd.concat(
        [
            spgc_strain_stats[["species", "derep_clust"]].assign(genome_type="SPGC"),
            ref_strain_stats[["species", "derep_clust", "genome_type"]],
        ]
    )
    .groupby("species")
    .value_counts()
    .unstack("genome_type", fill_value=0)
    .gt(0)
)
# Rows: species; columns: presence pattern; values: number of clusters with that pattern.
d = (
    _cluster_presence.groupby(level="species")
    .value_counts()
    .unstack("species", fill_value=0)
    .T
)

d

In [None]:
# NOTE(review): the (False, False, True) key assumes "SPGC" is the last genome_type
# column of `d`, i.e. clusters containing only SPGC strains — confirm the column order.
num_novel_strains = d[(False, False, True)]
num_existing_strains = d.sum(axis=1) - num_novel_strains
plt.scatter(num_existing_strains, num_novel_strains)
plt.xscale("symlog")
plt.yscale("symlog")

In [None]:
# Per-gene prevalence among reference genomes vs. SPGC strains.
# Genes lacking either prevalence (NaN, handled elsewhere via dropna/fillna) are
# dropped first: pearsonr returns NaN and hist2d cannot auto-range with NaNs present.
_prev = gene_stats[["prevalence_ref", "prevalence_spgc"]].dropna()
print(sp.stats.pearsonr(_prev.prevalence_ref, _prev.prevalence_spgc))
plt.hist2d(
    _prev.prevalence_ref,
    _prev.prevalence_spgc,
    norm=mpl.colors.PowerNorm(1 / 3),
)
None

In [None]:
# SPGC strains: gene-content vs SNP-profile dissimilarity to the nearest reference.
# OLS with a shared slope on log10(SNP dissimilarity) plus a per-species offset in
# sum-to-zero coding, so "Intercept" is the unweighted mean of the species intercepts.
# NOTE(review): log10 needs nearest_ref_geno_diss > 0; exact zeros would give -inf.
d0 = spgc_strain_stats.sort_values("taxonomy_string")
focal_species = "102438"  # highlighted in red on the left panel
_slope_term = "np.log10(nearest_ref_geno_diss)"

fit = smf.ols(
    "nearest_ref_gene_diss ~ np.log10(nearest_ref_geno_diss) + C(species, Sum)",
    data=d0,
).fit()
slope = fit.params[_slope_term]
p_slope = fit.pvalues[_slope_term]
r2_adj = fit.rsquared_adj

fig, axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

# Left panel: all strains (grey), the focal species (red), and the shared slope.
ax = axs[0]
ax.scatter(
    "nearest_ref_geno_diss",
    "nearest_ref_gene_diss",
    data=d0,
    label="__nolegend__",
    color="grey",
    s=2,
    alpha=0.5,
)
ax.scatter(
    "nearest_ref_geno_diss",
    "nearest_ref_gene_diss",
    data=d0[d0.species == focal_species],
    label=focal_species,
    color="red",
    s=20,
    alpha=0.5,
)
xx = np.logspace(-3, -1)
ax.plot(
    xx,
    np.log10(xx) * slope + fit.params["Intercept"],
    color="k",
    label="Overall Slope",
)
# Legend is built after all labelled artists exist so "Overall Slope" is included.
ax.legend(loc="upper right")
ax.annotate(
    f"$\\beta_{{\\mathrm{{slope}}}}$={slope:0.3f}\n$P_{{\\mathrm{{slope}}}}$={p_slope:0.0e}\n$R^2_{{\\mathrm{{adj}}}}$={r2_adj:0.2%}",
    xy=(0.1, 0.9),
    xycoords="axes fraction",
    va="top",
)
ax.set_xscale("symlog", linthresh=1e-4)
ax.set_xlabel("SNP Profile Dissimilarity")
ax.set_ylabel("Gene Content Dissimilarity\n(filtered, Jaccard)")

# Right panel: gene-content dissimilarity distribution per phylum.
ax = axs[1]
sns.boxplot(x="p__", y="nearest_ref_gene_diss", data=d0, ax=ax, palette=phylum_palette)
lib.plot.rotate_xticklabels(ax=ax)
ax.set_ylabel("")

ax.set_ylim(-0.005, 0.605)

In [None]:
# Reference strains: gene-content vs SNP-profile dissimilarity to the nearest other
# reference. Same model as for SPGC strains: shared slope on log10(SNP dissimilarity)
# plus a sum-to-zero per-species offset ("Intercept" = unweighted mean over species).
# NOTE(review): log10 needs nearest_ref_geno_diss > 0; exact zeros would give -inf.
d0 = ref_strain_stats.sort_values("taxonomy_string")
focal_species = "102438"  # highlighted in red on the left panel
_slope_term = "np.log10(nearest_ref_geno_diss)"

fit = smf.ols(
    "nearest_ref_gene_diss ~ np.log10(nearest_ref_geno_diss) + C(species, Sum)",
    data=d0,
).fit()
slope = fit.params[_slope_term]
p_slope = fit.pvalues[_slope_term]
r2_adj = fit.rsquared_adj

fig, axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

# Left panel: all strains (grey), the focal species (red), and the shared slope.
ax = axs[0]
ax.scatter(
    "nearest_ref_geno_diss",
    "nearest_ref_gene_diss",
    data=d0,
    label="__nolegend__",
    color="grey",
    s=2,
    alpha=0.5,
)
ax.scatter(
    "nearest_ref_geno_diss",
    "nearest_ref_gene_diss",
    data=d0[d0.species == focal_species],
    label=focal_species,
    color="red",
    s=20,
    alpha=0.5,
)
xx = np.logspace(-3, -1)
ax.plot(
    xx,
    np.log10(xx) * slope + fit.params["Intercept"],
    color="k",
    label="Overall Slope",
)
# Legend is built after all labelled artists exist so "Overall Slope" is included.
ax.legend(loc="upper right")
ax.annotate(
    f"$\\beta_{{\\mathrm{{slope}}}}$={slope:0.3f}\n$P_{{\\mathrm{{slope}}}}$={p_slope:0.0e}\n$R^2_{{\\mathrm{{adj}}}}$={r2_adj:0.2%}",
    xy=(0.1, 0.9),
    xycoords="axes fraction",
    va="top",
)
ax.set_xscale("symlog", linthresh=1e-4)
ax.set_xlabel("SNP Profile Dissimilarity")
ax.set_ylabel("Gene Content Dissimilarity\n(filtered, Jaccard)")

# Right panel: gene-content dissimilarity distribution per phylum.
ax = axs[1]
sns.boxplot(x="p__", y="nearest_ref_gene_diss", data=d0, ax=ax, palette=phylum_palette)
lib.plot.rotate_xticklabels(ax=ax)
ax.set_ylabel("")

ax.set_ylim(-0.005, 0.605)

In [None]:
# SPGC strains: gene-content dissimilarity modelled as a centered natural cubic
# regression spline (20 df) of SNP dissimilarity plus a sum-coded species offset.
# The AIC is printed so alternative fits can be compared.
d0 = spgc_strain_stats.sort_values("taxonomy_string")
fit = smf.ols(
    f"nearest_ref_gene_diss ~ cr(nearest_ref_geno_diss, 20, constraints='center') + C(species, Sum)",
    data=d0,
).fit()
print(fit.aic)

fig, axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

# Left panel: all strains plus the species-averaged fitted curve.
ax = axs[0]
ax.scatter(
    "nearest_ref_geno_diss",
    "nearest_ref_gene_diss",
    data=d0,
    label="__nolegend__",
    color="grey",
    s=2,
    alpha=0.5,
)

# Predict for one arbitrary species, then subtract that species' sum-coded effect,
# leaving Intercept + spline (the curve averaged over species effects).
# NOTE(review): raises KeyError if "100003" is absent or is the omitted (last) Sum level.
_arbitrary_species = "100003"
_arbitrary_species_term = f"C(species, Sum)[S.{_arbitrary_species}]"
xx = np.logspace(-4, -1)
yy = (
    fit.predict(
        pd.DataFrame(dict(nearest_ref_geno_diss=xx)).assign(species=_arbitrary_species)
    ).values
    - fit.params[_arbitrary_species_term]
)
ax.plot(
    xx,
    yy,
    color="k",
)

# ax.plot("nearest_ref_geno_diss", "gene_dist_predict", data=d2, label="__nolegend__")
# ax.legend(bbox_to_anchor=(1, 1), markerscale=4)
ax.set_xscale("symlog", linthresh=1e-4)

ax.set_xlabel("SNP Profile Dissimilarity")
ax.set_ylabel("Gene Content Dissimilarity\n(filtered, Jaccard)")
# plt.scatter(spgc_strain_stats.nearest_ref_geno_diss, spgc_strain_stats.nearest_ref_gene_raw_diss, s=1)


# Right panel: per-species predicted gene dissimilarity at SNP dissimilarity 1e-3
# (very closely related strains), grouped by phylum.
ax = axs[1]
_species_list = d0.species.unique()
d1 = (
    pd.Series(
        fit.predict(
            pd.DataFrame(dict(species=_species_list)).assign(nearest_ref_geno_diss=1e-3)
        ).values,
        index=_species_list,
    )
    .to_frame(name="species_min_predicted_minimum_gene_diss")
    .join(species_taxonomy)
)
sns.boxplot(
    x="p__",
    y="species_min_predicted_minimum_gene_diss",
    data=d1,
    ax=ax,
    palette=phylum_palette,
)
lib.plot.rotate_xticklabels(ax=ax)
ax.set_ylabel("")
ax.set_ylim(-0.005, 0.605)

# Kept for the ref-vs-SPGC comparison scatter.
spgc_near_neighbor_predicted_gene_diss = d1.species_min_predicted_minimum_gene_diss

In [None]:
# Reference strains: same spline + sum-coded species model as for SPGC strains.
# The AIC is printed so alternative fits can be compared.
d0 = ref_strain_stats.sort_values("taxonomy_string")
fit = smf.ols(
    f"nearest_ref_gene_diss ~ cr(nearest_ref_geno_diss, 20, constraints='center') + C(species, Sum)",
    data=d0,
).fit()
print(fit.aic)

fig, axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

# Left panel: all strains plus the species-averaged fitted curve.
ax = axs[0]
ax.scatter(
    "nearest_ref_geno_diss",
    "nearest_ref_gene_diss",
    data=d0,
    label="__nolegend__",
    color="grey",
    s=2,
    alpha=0.5,
)

# Predict for one arbitrary species, then subtract that species' sum-coded effect,
# leaving Intercept + spline (the curve averaged over species effects).
# NOTE(review): raises KeyError if "100003" is absent or is the omitted (last) Sum level.
_arbitrary_species = "100003"
_arbitrary_species_term = f"C(species, Sum)[S.{_arbitrary_species}]"
xx = np.logspace(-4, -1)
yy = (
    fit.predict(
        pd.DataFrame(dict(nearest_ref_geno_diss=xx)).assign(species=_arbitrary_species)
    ).values
    - fit.params[_arbitrary_species_term]
)
ax.plot(
    xx,
    yy,
    color="k",
)

# ax.plot("nearest_ref_geno_diss", "gene_dist_predict", data=d2, label="__nolegend__")
# ax.legend(bbox_to_anchor=(1, 1), markerscale=4)
ax.set_xscale("symlog", linthresh=1e-4)

ax.set_xlabel("SNP Profile Dissimilarity")
ax.set_ylabel("Gene Content Dissimilarity\n(filtered, Jaccard)")
# plt.scatter(spgc_strain_stats.nearest_ref_geno_diss, spgc_strain_stats.nearest_ref_gene_raw_diss, s=1)


# Right panel: per-species predicted gene dissimilarity at SNP dissimilarity 1e-3
# (very closely related strains), grouped by phylum.
ax = axs[1]
_species_list = d0.species.unique()
d1 = (
    pd.Series(
        fit.predict(
            pd.DataFrame(dict(species=_species_list)).assign(nearest_ref_geno_diss=1e-3)
        ).values,
        index=_species_list,
    )
    .to_frame(name="species_min_predicted_minimum_gene_diss")
    .join(species_taxonomy)
)
sns.boxplot(
    x="p__",
    y="species_min_predicted_minimum_gene_diss",
    data=d1,
    ax=ax,
    palette=phylum_palette,
)
lib.plot.rotate_xticklabels(ax=ax)
ax.set_ylabel("")
ax.set_ylim(-0.005, 0.605)

# Kept for the ref-vs-SPGC comparison scatter.
ref_near_neighbor_predicted_gene_diss = d1.species_min_predicted_minimum_gene_diss

In [None]:
# {ref,spgc}_near_neighbor_predicted_gene_diss: for each species, the model-predicted
# gene-content dissimilarity at a genotype distance of 1e-3 (very closely related strains).
# The figure shows:
#   1. the relatively large mean of these values, suggesting that going with the
#      database's nearest neighbour is fraught and gene-content dissimilarity matters;
#   2. the strong ref/SPGC correlation, suggesting SPGC estimates reflect reality;
#   3. the high variance, suggesting different taxa behave differently in this regard.
d0 = pd.DataFrame(
    {
        "ref": ref_near_neighbor_predicted_gene_diss,
        "spgc": spgc_near_neighbor_predicted_gene_diss,
    }
).join(species_taxonomy)

fig, ax = plt.subplots(figsize=(10, 10))
for phylum, phylum_species in d0.groupby("p__"):
    ax.scatter(
        "ref",
        "spgc",
        data=phylum_species,
        color=phylum_palette[phylum],
        s=50,
        alpha=0.8,
        label=phylum,
    )
ax.legend()
# Dashed y = x reference line.
ax.plot([0, 0.2], [0, 0.2], lw=1, linestyle="--", color="k")
_axis_limits = (-0.04, 0.42)
ax.set_xlim(*_axis_limits)
ax.set_ylim(*_axis_limits)
ax.set_aspect(1)

In [None]:
# Joint model over reference and SPGC strains: spline of SNP dissimilarity +
# sum-coded species offset + a fixed shift for SPGC strains (is_spgc).
d0 = pd.concat(
    [ref_strain_stats.assign(is_spgc=False), spgc_strain_stats.assign(is_spgc=True)]
).sort_values("taxonomy_string")
fit = smf.ols(
    f"nearest_ref_gene_diss ~ cr(nearest_ref_geno_diss, 20, constraints='center') + C(species, Sum) + is_spgc",
    data=d0,
).fit()
print(fit.aic)

fig, axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

ax = axs[0]
ax.scatter(
    "nearest_ref_geno_diss",
    "nearest_ref_gene_diss",
    data=d0,
    label="__nolegend__",
    color="grey",
    s=2,
    alpha=0.5,
)

# Species-averaged curves (arbitrary species' prediction minus its sum-coded effect),
# once for SPGC (dark red) and once for reference strains (black).
# NOTE(review): raises KeyError if "100003" is absent or is the omitted (last) Sum level.
_arbitrary_species = "100003"
_arbitrary_species_term = f"C(species, Sum)[S.{_arbitrary_species}]"
xx = np.logspace(-4, -1)
yy_spgc = (
    fit.predict(
        pd.DataFrame(dict(nearest_ref_geno_diss=xx)).assign(
            species=_arbitrary_species, is_spgc=True
        )
    ).values
    - fit.params[_arbitrary_species_term]
)
yy_ref = (
    fit.predict(
        pd.DataFrame(dict(nearest_ref_geno_diss=xx)).assign(
            species=_arbitrary_species, is_spgc=False
        )
    ).values
    - fit.params[_arbitrary_species_term]
)
ax.plot(
    xx,
    yy_spgc,
    color="darkred",
)
ax.plot(
    xx,
    yy_ref,
    color="k",
)

# ax.plot("nearest_ref_geno_diss", "gene_dist_predict", data=d2, label="__nolegend__")
# ax.legend(bbox_to_anchor=(1, 1), markerscale=4)
ax.set_xscale("symlog", linthresh=1e-4)

ax.set_xlabel("SNP Profile Dissimilarity")
ax.set_ylabel("Gene Content Dissimilarity\n(filtered, Jaccard)")
# plt.scatter(spgc_strain_stats.nearest_ref_geno_diss, spgc_strain_stats.nearest_ref_gene_raw_diss, s=1)


# Right panel: per-species prediction for reference strains, grouped by phylum.
# NOTE(review): evaluated at nearest_ref_geno_diss=1e-4 here, but at 1e-3 in the
# separate ref/SPGC cells — confirm which is intended before comparing panels.
ax = axs[1]
_species_list = d0.species.unique()
d1 = (
    pd.Series(
        fit.predict(
            pd.DataFrame(dict(species=_species_list)).assign(
                nearest_ref_geno_diss=1e-4, is_spgc=False
            )
        ).values,
        index=_species_list,
    )
    .to_frame(name="species_min_predicted_minimum_gene_diss")
    .join(species_taxonomy)
)
sns.boxplot(
    x="p__",
    y="species_min_predicted_minimum_gene_diss",
    data=d1,
    ax=ax,
    palette=phylum_palette,
)
lib.plot.rotate_xticklabels(ax=ax)
ax.set_ylabel("")
ax.set_ylim(-0.005, 0.605)

In [None]:
# Per-species medians for SPGC strains, with gene-content dissimilarity shifted by
# the is_spgc coefficient. Depends on `fit` from the joint ref+SPGC model cell
# above; running this cell after any other fit raises KeyError or gives wrong values.
d0 = (
    spgc_strain_stats.assign(
        nearest_ref_gene_diss_adj=lambda x: x.nearest_ref_gene_diss
        - fit.params["is_spgc[T.True]"]
    )
    .groupby("species")[
        ["nearest_ref_geno_diss", "nearest_ref_gene_diss", "nearest_ref_gene_diss_adj"]
    ]
    .median()
    .join(species_taxonomy)
)

# Top: median SNP dissimilarity per species; bottom: adjusted gene dissimilarity.
fig, axs = plt.subplots(2, sharex=True, figsize=(7, 10))
sns.stripplot(
    x="p__",
    y="nearest_ref_geno_diss",
    data=d0,
    size=10,
    alpha=0.7,
    palette=phylum_palette,
    ax=axs[0],
)
sns.stripplot(
    x="p__",
    y="nearest_ref_gene_diss_adj",
    data=d0,
    size=10,
    alpha=0.7,
    palette=phylum_palette,
    ax=axs[1],
)
lib.plot.rotate_xticklabels(ax=axs[1])
# ax.set_ylim(-0.005, 0.605)

### strain-gene-MWAS

In [None]:
def mwas_filt_func(x):
    """Row mask selecting testable MWAS genes.

    Keeps genes present in strictly between 20% and 80% of subjects and observed
    in more than 40 subjects. Usable directly as ``mwas_stats[mwas_filt_func]``.
    """
    enough_subjects = x.total_subjects > 40
    intermediate_prevalence = (x.subject_prevalence > 0.2) & (x.subject_prevalence < 0.8)
    return intermediate_prevalence & enough_subjects

In [None]:
# Testable MWAS genes, most significant first.
mwas_stats.loc[mwas_filt_func].sort_values(by="fisher_exact_pvalue_ibd")

In [None]:
# Histogram of MWAS p-values vs. the count expected per bin under a uniform null.
_mwas_tested = mwas_stats[mwas_filt_func]
bins = np.logspace(-6, 0, num=20)
plt.hist(_mwas_tested.fisher_exact_pvalue_ibd, bins=bins, alpha=0.5)
print(len(_mwas_tested))
plt.plot(bins[1:], np.diff(bins) * len(_mwas_tested))

plt.xscale("log")
plt.yscale("log")

In [None]:
# mwas_stats[mwas_filt_func].sort_values("fisher_exact_pvalue_ibd").head(50)#.join(species_taxonomy, on='species_id')

In [None]:
# Volcano-style plot: odds ratio (log x) vs. log10 p-value, most significant at top.
_volcano_data = mwas_stats[mwas_filt_func]
fig, ax = plt.subplots()
ax.scatter(
    "oddsratio_pc_ibd", "log10_fisher_exact_pvalue_ibd", data=_volcano_data, s=5
)
ax.set_xscale("log")
ax.invert_yaxis()

In [None]:
# Sorted log10 p-values of testable genes (a rank plot of significance).
plt.plot(
    mwas_stats[mwas_filt_func]["log10_fisher_exact_pvalue_ibd"]
    .sort_values()
    .values
)

In [None]:
# Annotated nominal hits (p < 1e-3). To focus on one species, additionally
# filter on e.g. species_id == '101345'.
_mwas_annotated = mwas_stats[mwas_filt_func].join(gene_meta)
_mwas_annotated[_mwas_annotated.fisher_exact_pvalue_ibd < 1e-3].head(50)

In [None]:
# Nominal hits (p < 1e-3) per species, with taxonomy and the number of tested genes.
_mwas_annotated = mwas_stats[mwas_filt_func].join(gene_meta)
(
    _mwas_annotated[lambda x: x.fisher_exact_pvalue_ibd < 1e-3]
    .species.value_counts()
    .to_frame()
    .join(species_taxonomy[["f__", "g__", "s__"]])
    .assign(total=_mwas_annotated.species.value_counts())
)

In [None]:
# Within one species: are nominal MWAS hits (p < 1e-3) enriched for genes whose
# cog_categories string contains "G"? Rows: in_G, columns: is_signif.
species_id = '101338'
_species_mwas = mwas_stats[mwas_filt_func][lambda x: (x.species_id == species_id)]
d = pd.DataFrame(dict(
    in_G=_species_mwas.join(gene_meta).cog_categories.fillna('').str.contains('G'),
    is_signif=_species_mwas.fisher_exact_pvalue_ibd < 1e-3,
))
# Force a full 2x2 integer table: with a plain unstack() a missing combination
# becomes NaN (or a dropped row/column) and fisher_exact fails.
contingency = (
    d.value_counts()
    .unstack(fill_value=0)
    .reindex(index=[False, True], columns=[False, True], fill_value=0)
)
print(sp.stats.fisher_exact(contingency))
contingency

In [None]:
# Across all species: are nominal MWAS hits (p < 1e-3) enriched for genes whose
# cog_categories string contains "G"? Rows: in_G, columns: is_signif.
_mwas_tested = mwas_stats[mwas_filt_func]
d = pd.DataFrame(dict(
    in_G=_mwas_tested.join(gene_meta).cog_categories.fillna('').str.contains('G'),
    is_signif=_mwas_tested.fisher_exact_pvalue_ibd < 1e-3,
))
# Force a full 2x2 integer table so fisher_exact always receives valid input.
contingency = (
    d.value_counts()
    .unstack(fill_value=0)
    .reindex(index=[False, True], columns=[False, True], fill_value=0)
)
print(sp.stats.fisher_exact(contingency))
contingency

In [None]:
# mwas_stats is sorted by p-value, so the first row is the most significant tested gene.
_mwas_annotated = mwas_stats[mwas_filt_func].join(gene_meta)
_mwas_annotated.iloc[0]

In [None]:
bins = np.logspace(-7, 0, num=50)
# Alternative linear binning: np.linspace(0, 1, num=51)

# IBD MWAS p-values for intermediate-prevalence genes, with two minimum-subject
# thresholds overlaid. `x` ends up holding the stricter (>40 subjects) set.
_intermediate_prevalence = (gene_stats.prevalence_mwas_subject > 0.2) & (
    gene_stats.prevalence_mwas_subject < 0.8
)
for _min_subjects in (20, 40):
    x = gene_stats[
        _intermediate_prevalence & (gene_stats.num_mwas_subject > _min_subjects)
    ].ibd_mwas_pvalue.sort_values()
    print(len(x))
    plt.hist(x, alpha=0.5, bins=bins)

# Expected count per bin under a uniform null, scaled to the >40-subject set.
plt.plot(bins[1:], np.diff(bins) * len(x))

plt.xscale("log")
plt.yscale("log")

In [None]:
# Odds ratio vs. p-value (both log axes; most significant at top) for genes at
# intermediate MWAS prevalence observed in more than 20 subjects.
_mwas_tested = gene_stats[
    (gene_stats.prevalence_mwas_subject > 0.2)
    & (gene_stats.prevalence_mwas_subject < 0.8)
    & (gene_stats.num_mwas_subject > 20)
]
ax = plt.gca()
ax.scatter(
    "ibd_mwas_oddsratio_pc",
    "ibd_mwas_pvalue",
    data=_mwas_tested,
    s=4,
    alpha=0.5,
)
ax.set_xscale("log")
ax.set_yscale("log")
ax.invert_yaxis()
ax.set_xlabel("Odds-Ratio (w/ pseudo-counts)")
ax.set_ylabel("P-value")

In [None]:
# Benjamini-Hochberg FDR over the tested set (intermediate-prevalence genes seen in
# more than 20 subjects), a density plot of odds ratio vs. p-value, and hit lists
# at several confidence thresholds.
d0 = (
    gene_stats[
        lambda x: (x.prevalence_mwas_subject > 0.2)
        & (x.prevalence_mwas_subject < 0.8)
        & (x.num_mwas_subject > 20)
    ]
    # fdrcorrection propagates a single NaN p-value to every corrected value,
    # so genes without a p-value are excluded before correction.
    .dropna(subset=["ibd_mwas_pvalue"])
    .assign(fdr=lambda x: fdrcorrection(x.ibd_mwas_pvalue)[1])
    .join(species_taxonomy, on="species")
    .sort_values("ibd_mwas_pvalue")
)

plt.hist2d(
    "ibd_mwas_oddsratio_pc",
    "ibd_mwas_pvalue",
    data=d0,
    bins=(np.logspace(-3, 3, num=100), np.logspace(-7, 0, num=100)),
    norm=mpl.colors.SymLogNorm(linthresh=2, linscale=3, vmin=1, vmax=1e5),
    cmin=1,
    cmap="copper_r",
)

plt.colorbar()

# Outline the FDR < 0.2 hits on top of the density.
plt.scatter(
    "ibd_mwas_oddsratio_pc",
    "ibd_mwas_pvalue",
    data=d0[(d0.fdr < 0.2)],
    s=20,
    marker="o",
    edgecolor="k",
    facecolor="none",
    lw=1,
    alpha=0.5,
)
plt.yscale("log")
plt.xscale("log")
plt.gca().invert_yaxis()
# The axis shows the odds ratio itself on a log scale, not log(odds ratio).
plt.xlabel("Odds-Ratio (w/ pseudo-counts)")
plt.ylabel("P-value")

high_confidence_mwas_hit = idxwhere(d0.fdr < 0.2)
lowish_confidence_mwas_hit = idxwhere(d0.fdr < 0.3)
lower_confidence_mwas_hit = idxwhere(d0.fdr < 0.4)
low_confidence_mwas_hit = idxwhere(d0.ibd_mwas_pvalue < 0.01)

print(
    len(high_confidence_mwas_hit),
    len(lowish_confidence_mwas_hit),
    len(lower_confidence_mwas_hit),
    len(low_confidence_mwas_hit),
)

In [None]:
# Top species by number of FDR < 0.2 hits, split by association direction
# (pos: odds ratio > 1, i.e. more often present in IBD).
(
    gene_meta.loc[high_confidence_mwas_hit]
    .assign(mwas_ibd_pos_assoc=(gene_stats.ibd_mwas_oddsratio_pc > 1).replace({True: "pos", False: "neg"}))[
        ["species", "mwas_ibd_pos_assoc"]
    ]
    .value_counts()
    .unstack(fill_value=0)
    # Ensure both direction columns exist even if every hit goes one way.
    .reindex(columns=["neg", "pos"], fill_value=0)
    .assign(total_hits=lambda x: x.neg + x.pos)
    .join(species_taxonomy[["g__", "s__", "taxonomy_string"]])
    .sort_values("total_hits", ascending=False)
    .head(10)
)

In [None]:
# Top species by number of nominal (p < 0.01) hits, split by association direction.
(
    gene_meta.loc[low_confidence_mwas_hit]
    .assign(mwas_ibd_pos_assoc=(gene_stats.ibd_mwas_oddsratio_pc > 1).replace({True: "pos", False: "neg"}))[
        ["species", "mwas_ibd_pos_assoc"]
    ]
    .value_counts()
    .unstack(fill_value=0)
    # Ensure both direction columns exist even if every hit goes one way.
    .reindex(columns=["neg", "pos"], fill_value=0)
    .assign(total_hits=lambda x: x.neg + x.pos)
    .join(species_taxonomy[["g__", "s__", "taxonomy_string"]])
    .sort_values("total_hits", ascending=False)
    .head(10)
)

In [None]:
# Top species by number of FDR < 0.4 hits, split by association direction.
(
    gene_meta.loc[lower_confidence_mwas_hit]
    .assign(mwas_ibd_pos_assoc=(gene_stats.ibd_mwas_oddsratio_pc > 1).replace({True: "pos", False: "neg"}))[
        ["species", "mwas_ibd_pos_assoc"]
    ]
    .value_counts()
    .unstack(fill_value=0)
    # Ensure both direction columns exist even if every hit goes one way.
    .reindex(columns=["neg", "pos"], fill_value=0)
    .assign(total_hits=lambda x: x.neg + x.pos)
    .join(species_taxonomy[["g__", "s__", "taxonomy_string"]])
    .sort_values("total_hits", ascending=False)
    .head(10)
)

In [None]:
# Drill into one species' FDR < 0.4 hits: top COG categories, top SPGC co-clusters,
# and the 20 most significant annotated hits.
_species = '103681'
_is_species_hit = (gene_stats.species == _species) & gene_stats.index.isin(
    lower_confidence_mwas_hit
)
_species_hit = gene_stats.index[_is_species_hit.to_numpy()]
print(
    gene_x_cog_category_matrix.loc[_species_hit]
    .sum()
    .sort_values(ascending=False)
    .to_frame('tally')
    .join(cog_category_description)
    .head(5)
)
print(gene_stats.loc[_species_hit, 'coclust_label_spgc'].value_counts().head(5))
(
    gene_meta.loc[_species_hit]
    .join(gene_stats, rsuffix='_')
    .sort_values('ibd_mwas_pvalue')
    .head(20)
)

In [None]:
# Most frequent (species, SPGC co-cluster) combinations among FDR < 0.4 hits.
(
    gene_stats.loc[lower_confidence_mwas_hit, ['species', 'coclust_label_spgc']]
    .value_counts()
    .head(10)
)

In [None]:
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score
# Species with more than 25 SPGC strains.
_spgc_strains_per_species = spgc_strain_stats.species.value_counts()
many_strains_species_list = idxwhere(_spgc_strains_per_species > 25)

In [None]:
def _coclust_ami(species):
    """AMI between SPGC- and ref-derived gene co-cluster labels for one species.

    Only genes with a non-negative SPGC co-cluster label are compared
    (presumably negative labels mean "unclustered" — TODO confirm).
    """
    genes = idxwhere((gene_stats.species == species) & (gene_stats.coclust_label_spgc >= 0))
    labels = gene_stats.loc[genes]
    return adjusted_mutual_info_score(labels.coclust_label_spgc, labels.coclust_label_ref)


result = {species: _coclust_ami(species) for species in tqdm(many_strains_species_list)}

In [None]:
# Per-species AMI; higher means the SPGC and ref co-clusterings agree more.
pd.Series(data=result)

In [None]:
def _cog_category_fractions(gene_by_category):
    """Tally genes per COG category, sorted ascending, with each category's share."""
    return (
        gene_by_category.sum()
        .sort_values()
        .to_frame("tally")
        .join(cog_category_description)
        .assign(frac=lambda x: x.tally / x.tally.sum())
    )


# COG composition of FDR < 0.3 hits, next to the composition over all genes.
_cog_category_fractions(gene_x_cog_category_matrix.loc[lowish_confidence_mwas_hit]).assign(
    frac_overall=_cog_category_fractions(gene_x_cog_category_matrix).frac
)

In [None]:
# Per COG category: Fisher's exact test of whether nominal MWAS hits (p < 0.01)
# are over/under-represented among genes annotated with that category.
x = gene_x_cog_category_matrix
y = gene_x_cog_category_matrix.index.to_series().isin(low_confidence_mwas_hit)

for category in x.columns:
    # Always build a full 2x2 table: a category with no (or only) member genes
    # would otherwise give a 2x1 table and fisher_exact would raise ValueError.
    dh = (
        pd.DataFrame({"hit": y, "in_category": x[category]})
        .value_counts()
        .unstack(fill_value=0)
        .reindex(index=[False, True], columns=[False, True], fill_value=0)
    )
    print(category)
    print(sp.stats.fisher_exact(dh))
    print()

In [None]:
# Per COG category: nominal hits split by association direction, plus category size.
_hit_direction = (
    (gene_stats.ibd_mwas_oddsratio_pc > 1)
    .to_frame("mwas_ibd_pos_assoc")
    .replace({True: "pos", False: "neg"})
)
_category_hits_by_direction = (
    gene_x_cog_category_matrix.loc[low_confidence_mwas_hit]
    .join(_hit_direction)
    .groupby("mwas_ibd_pos_assoc")
    .sum()
)
_category_hits_by_direction.T.join(cog_category_description).assign(
    total_genes_in_category=gene_x_cog_category_matrix.sum()
)

### Prevalence Class Enrichment

In [None]:
def _assign_prevalence_class(p):
    if p > 0.95:
        return "core"
    elif p > 0.1:
        return "shell"
    elif p < 0.1:
        return "cloud"

In [None]:
# NOTE(review): cog_category_gene_class_enrichment_test is only created by a later
# cell, so this raises NameError on a fresh kernel (Restart & Run All).
# Move this cell below the enrichment-test cell or delete it.
cog_category_gene_class_enrichment_test.T

In [None]:
# Parenthesis closed; the unclosed "(" was a SyntaxError.
# NOTE(review): cog_category_gene_class_enrichment_test is only created by a later
# cell, so this still raises NameError on a fresh kernel; move it below that cell.
(cog_category_gene_class_enrichment_test)

In [None]:
# NOTE(review): on a linear run `d` here is the in_G/is_signif frame from the
# all-species COG "G" cell, not the frame built in the next cell; likely a
# leftover scratch cell.
d

In [None]:
# Fisher's exact test for every (prevalence class, COG category) pair: are genes
# of a given SPGC prevalence class enriched for a COG category?
# Genes without an SPGC prevalence are dropped before classification.
d = (
    gene_stats.dropna(subset=["prevalence_spgc"])
    .assign(
        prevalence_class=lambda x: x.prevalence_spgc.dropna().map(
            _assign_prevalence_class
        )
    )
    .join(gene_meta, rsuffix="_")
    .join(gene_x_cog_category_matrix)
)

x = d["prevalence_class"]
# Gene x COG-letter membership; genes missing from the COG matrix (NaN after the
# left join) count as not belonging to any category.
y = (
    d.loc[:, cog_category_order]
    .fillna(False)
)

cog_category_gene_class_enrichment_test = {}

for _prevalence_class, _cog_category in tqdm(
    list(product(prevalence_class_order, cog_category_order))
):
    # 2x2 table: rows = gene in this prevalence class?, columns = gene in this COG
    # category?; reindex + fillna guarantee all four cells exist.
    contingency_table = (
        pd.DataFrame(
            dict(
                is_prev_class=(x == _prevalence_class),
                is_cog_category=y[_cog_category],
            )
        )
        .value_counts()
        .unstack()
        .reindex(index=[False, True], columns=[False, True])
        .fillna(0)
    )
    # +1 pseudo-count on every cell keeps the odds ratio finite and non-zero.
    contingency_table_p1 = contingency_table + 1
    _test = sp.stats.fisher_exact(contingency_table)
    cog_category_gene_class_enrichment_test[(_cog_category, _prevalence_class)] = dict(
        pvalue=_test[1],
        # genes both in this prevalence class and in this COG category
        gene_count=contingency_table.loc[True, True],
        odds_ratio_pc=(
            (
                contingency_table_p1.loc[True, True]
                / contingency_table_p1.loc[True, False]
            )
            / (
                contingency_table_p1.loc[False, True]
                / contingency_table_p1.loc[False, False]
            )
        ),
    )

# Tuple keys become MultiIndex columns; transpose to one row per (category, class).
cog_category_gene_class_enrichment_test = (
    pd.DataFrame(
        cog_category_gene_class_enrichment_test,
    )
    .T.rename_axis(index=["cog_category", "prevalence_class"])
    .assign(
        negative_log10_pvalue=lambda x: -np.log10(x.pvalue),
        log2_odds_ratio_pc=lambda x: np.log2(x.odds_ratio_pc),
    )
)

In [None]:
# log2 odds ratios as a category x prevalence-class table, most core-enriched first.
_log2_or_by_class = cog_category_gene_class_enrichment_test[
    "log2_odds_ratio_pc"
].unstack("prevalence_class")
_log2_or_by_class.sort_values(by="core", ascending=False).join(cog_category_description)

In [None]:
# Heatmap values: log2 odds ratio per COG description x prevalence class, with
# non-finite or missing values shown as 0 (no enrichment).
x = (
    cog_category_gene_class_enrichment_test["log2_odds_ratio_pc"]
    .unstack("prevalence_class")
    .replace([np.inf, -np.inf], np.nan)
    .join(cog_category_description)
    .set_index("description")
    .loc[:, prevalence_class_order]
    .fillna(0)
)


def _assign_significance_marker(pvalue):
    if pvalue < 0.0001:
        return "*"
    else:
        return ""


# annot = (cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker) + '|' + cog_category_gene_class_enrichment_test.gene_count.astype(int).astype(str)).unstack('prevalence_class')[prevalence_class_order]
annot = (
    cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker)
    .unstack("prevalence_class")
    .join(cog_category_description)
    .set_index("description")[prevalence_class_order]
)
# annot = cog_category_gene_class_enrichment_test.gene_count.unstack('prevalence_class')[prevalence_class_order].astype(int)

_row_order = x["core"].sort_values(ascending=False).index
# x, annot = lib.pandas_util.align_indexes(x, annot)

fig, ax = plt.subplots(figsize=(5, 12))
ax = sns.heatmap(
    x.reindex(_row_order),
    annot=annot.reindex(_row_order),
    fmt="",
    cmap="coolwarm",
    center=0,
    vmin=-3,
    vmax=3,
    cbar_kws=dict(
        use_gridspec=True, location="left", label="log2(odds ratio)", extend="both"
    ),
    ax=ax,
    yticklabels=1,
    xticklabels=1,
    annot_kws=dict(va="center"),
    # norm=mpl.colors.SymLogNorm(linthresh=1e1),
    # center=0,
)

ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=-0, va="center")

In [None]:
# Per-gene prevalence from the reference estimate vs. the spgc estimate. A gene
# missing from one estimate counts as prevalence 0; genes absent from both are
# dropped.
d = gene_stats.assign(
    prevalence_ref=lambda x: x.prevalence_ref.fillna(0),
    prevalence_spgc=lambda x: x.prevalence_spgc.fillna(0),
)[lambda x: (x.prevalence_ref > 0) | (x.prevalence_spgc > 0)]

fig, ax = plt.subplots()
ax.hist2d(
    "prevalence_ref",
    "prevalence_spgc",
    data=d,
    bins=np.linspace(0, 1, num=26),
    # Strongly compressed colour scale so sparse off-diagonal bins stay visible.
    norm=mpl.colors.PowerNorm(1 / 5),
)
ax.set_xlabel("prevalence (ref)")
ax.set_ylabel("prevalence (spgc)")
print(sp.stats.pearsonr(d["prevalence_ref"], d["prevalence_spgc"]))

In [None]:
# Same comparison, restricted to genes of intermediate prevalence (strictly
# between 0.1 and 0.9 in BOTH estimates), to check the agreement away from the
# near-0 / near-1 extremes.
d = gene_stats.assign(
    prevalence_ref=lambda x: x.prevalence_ref.fillna(0),
    prevalence_spgc=lambda x: x.prevalence_spgc.fillna(0),
)[
    lambda x: (x.prevalence_ref > 0.1)
    & (x.prevalence_spgc > 0.1)
    & (x.prevalence_ref < 0.9)
    & (x.prevalence_spgc < 0.9)
]

fig, ax = plt.subplots()
ax.hist2d(
    "prevalence_ref",
    "prevalence_spgc",
    data=d,
    bins=np.linspace(0, 1, num=26),
)
ax.set_xlabel("prevalence (ref)")
ax.set_ylabel("prevalence (spgc)")
print(sp.stats.pearsonr(d["prevalence_ref"], d["prevalence_spgc"]))

In [None]:
# Per-strain fraction of genes in each spgc prevalence partition: each row's
# core / shell / cloud tallies normalized to sum to 1.
d = (
    spgc_strain_stats.reset_index()
    .set_index(["species", "strain"])[
        ["spgc_core_gene_tally", "spgc_shell_gene_tally", "spgc_cloud_gene_tally"]
    ]
    .apply(lambda x: x / x.sum(), axis=1)
)
# Shared bin edges on [0, 1]; per-call default bins would differ between the
# three overlaid histograms and make them misleading to compare.
_fraction_bins = np.linspace(0, 1, num=21)
plt.hist(d.spgc_core_gene_tally, bins=_fraction_bins, label="core", alpha=0.7)
plt.hist(d.spgc_shell_gene_tally, bins=_fraction_bins, label="shell", alpha=0.7)
plt.hist(d.spgc_cloud_gene_tally, bins=_fraction_bins, label="cloud", alpha=0.7)
plt.xlabel("fraction of strain's genes")
plt.ylabel("strains")
plt.legend();

In [None]:
# Per-strain fraction of genes in each reference prevalence partition: each
# row's core / shell / cloud tallies normalized to sum to 1.
d = (
    ref_strain_stats.reset_index()
    .set_index(["species", "strain"])[
        ["ref_core_gene_tally", "ref_shell_gene_tally", "ref_cloud_gene_tally"]
    ]
    .apply(lambda x: x / x.sum(), axis=1)
)
# Shared bin edges on [0, 1]; per-call default bins would differ between the
# three overlaid histograms and make them misleading to compare.
_fraction_bins = np.linspace(0, 1, num=21)
plt.hist(d.ref_core_gene_tally, bins=_fraction_bins, label="core", alpha=0.7)
plt.hist(d.ref_shell_gene_tally, bins=_fraction_bins, label="shell", alpha=0.7)
plt.hist(d.ref_cloud_gene_tally, bins=_fraction_bins, label="cloud", alpha=0.7)
plt.xlabel("fraction of strain's genes")
plt.ylabel("strains")
plt.legend();

In [None]:
# Per-species median (across strains) of the fraction of genes in each spgc
# prevalence partition.
d = (
    spgc_strain_stats.reset_index()
    .set_index(["species", "strain"])[
        ["spgc_core_gene_tally", "spgc_shell_gene_tally", "spgc_cloud_gene_tally"]
    ]
    .apply(lambda x: x / x.sum(), axis=1)
    .groupby(level="species")
    .median()
)

# Shared bin edges on [0, 1]; per-call default bins would differ between the
# three overlaid histograms and make them misleading to compare.
_fraction_bins = np.linspace(0, 1, num=21)
plt.hist(d.spgc_core_gene_tally, bins=_fraction_bins, label="core", alpha=0.7)
plt.hist(d.spgc_shell_gene_tally, bins=_fraction_bins, label="shell", alpha=0.7)
plt.hist(d.spgc_cloud_gene_tally, bins=_fraction_bins, label="cloud", alpha=0.7)
plt.xlabel("median fraction of strain's genes")
plt.ylabel("species")
plt.legend();

In [None]:
# Per-species median (across strains) of the fraction of genes in each
# reference prevalence partition.
d = (
    ref_strain_stats.reset_index()
    .set_index(["species", "strain"])[
        ["ref_core_gene_tally", "ref_shell_gene_tally", "ref_cloud_gene_tally"]
    ]
    .apply(lambda x: x / x.sum(), axis=1)
    .groupby(level="species")
    .median()
)
# Shared bin edges on [0, 1]; per-call default bins would differ between the
# three overlaid histograms and make them misleading to compare.
_fraction_bins = np.linspace(0, 1, num=21)
plt.hist(d.ref_core_gene_tally, bins=_fraction_bins, label="core", alpha=0.7)
plt.hist(d.ref_shell_gene_tally, bins=_fraction_bins, label="shell", alpha=0.7)
plt.hist(d.ref_cloud_gene_tally, bins=_fraction_bins, label="cloud", alpha=0.7)
plt.xlabel("median fraction of strain's genes")
plt.ylabel("species")
plt.legend();

In [None]:
# Per-species median fraction of genes in each spgc prevalence partition,
# joined to taxonomy so species can be grouped by phylum (p__).
d0 = (
    spgc_strain_stats.reset_index()
    .set_index(["species", "strain"])[
        ["spgc_core_gene_tally", "spgc_shell_gene_tally", "spgc_cloud_gene_tally"]
    ]
    .apply(lambda x: x / x.sum(), axis=1)
    .groupby(level="species")
    .median()
    .join(species_taxonomy)
)
_phylum_list = d0.p__.unique()

# One panel per partition; one KDE curve per phylum within each panel.
fig, axs = plt.subplots(3, sharex=True)
# Support of the plotted values (they are fractions).
bins = np.linspace(0, 1, num=21)
for partition, ax in zip(
    ["spgc_core_gene_tally", "spgc_shell_gene_tally", "spgc_cloud_gene_tally"], axs
):
    for p__, d1 in d0.groupby("p__"):
        # Clip the KDE to [0, 1] so no density is smeared outside the valid
        # range of a fraction.
        sns.kdeplot(
            d1[partition],
            color=phylum_palette[p__],
            ax=ax,
            clip=(bins[0], bins[-1]),
            label="_nolegend_",
        )
    ax.set_ylabel(partition[len("spgc_") : -len("_gene_tally")])
axs[-1].set_xlabel("median fraction of strain's genes")

# Empty proxy lines give exactly one legend entry per phylum.
for p__, _ in d0.groupby("p__"):
    axs[0].plot([], [], color=phylum_palette[p__], label=p__)
axs[0].legend(bbox_to_anchor=(1, 1));

In [None]:
# Per-species median fraction of genes in each reference prevalence partition,
# joined to taxonomy so species can be grouped by phylum (p__).
d0 = (
    ref_strain_stats.reset_index()
    .set_index(["species", "strain"])[
        ["ref_core_gene_tally", "ref_shell_gene_tally", "ref_cloud_gene_tally"]
    ]
    .apply(lambda x: x / x.sum(), axis=1)
    .groupby(level="species")
    .median()
    .join(species_taxonomy)
)
_phylum_list = d0.p__.unique()

# One panel per partition; one KDE curve per phylum within each panel.
fig, axs = plt.subplots(3, sharex=True)
# Support of the plotted values (they are fractions).
bins = np.linspace(0, 1, num=21)
for partition, ax in zip(
    ["ref_core_gene_tally", "ref_shell_gene_tally", "ref_cloud_gene_tally"], axs
):
    for p__, d1 in d0.groupby("p__"):
        # Clip the KDE to [0, 1] so no density is smeared outside the valid
        # range of a fraction.
        sns.kdeplot(
            d1[partition],
            color=phylum_palette[p__],
            ax=ax,
            clip=(bins[0], bins[-1]),
            label="_nolegend_",
        )
    ax.set_ylabel(partition[len("ref_") : -len("_gene_tally")])
axs[-1].set_xlabel("median fraction of strain's genes")

# Empty proxy lines give exactly one legend entry per phylum.
for p__, _ in d0.groupby("p__"):
    axs[0].plot([], [], color=phylum_palette[p__], label=p__)
axs[0].legend(bbox_to_anchor=(1, 1));

### "Phylogenetic" Signal

In [None]:
# Species with more than 100 rows in ref_strain_stats (presumably reference
# strains — TODO confirm), and their genes of intermediate reference prevalence
# (strictly between 0.2 and 0.8).
_species_list = idxwhere(ref_strain_stats.species.value_counts() > 100)
print(len(_species_list))
_gene_list = idxwhere(
    gene_stats.species.isin(_species_list)
    & (gene_stats.prevalence_ref > 0.2)
    & (gene_stats.prevalence_ref < 0.8)
)
print(len(_gene_list))

d = gene_stats.loc[_gene_list]
fig, ax = plt.subplots()
ax.hist(d.phylogenetic_i_ref, bins=100)
ax.set_yscale("log")
ax.set_xlabel("phylogenetic_i_ref")
ax.set_ylabel("genes")

In [None]:
# Display the gene metadata table (used below to inspect individual genes).
gene_meta

In [None]:
# Metadata for genes in the filtered set `_gene_list` whose reference
# phylogenetic_i is below -0.01. Recomputed from `_gene_list` instead of reading
# the generic name `d`, which many other cells reassign, so re-running this cell
# out of order still shows the intended genes.
gene_meta.loc[idxwhere(gene_stats.loc[_gene_list].phylogenetic_i_ref < -0.01)]

In [None]:
# Agreement between the ref and spgc phylogenetic_i statistics, for genes where
# both are defined.
d = gene_stats.dropna(subset=["phylogenetic_i_ref", "phylogenetic_i_spgc"])

fig, ax = plt.subplots()
ax.hist2d(
    "phylogenetic_i_ref",
    "phylogenetic_i_spgc",
    data=d,
    bins=50,
    # Symmetric-log colour scale: linear for counts up to 1, log above.
    norm=mpl.colors.SymLogNorm(linthresh=1),
)
ax.set_xlabel("phylogenetic_i_ref")
ax.set_ylabel("phylogenetic_i_spgc")
print(sp.stats.pearsonr(d["phylogenetic_i_ref"], d["phylogenetic_i_spgc"]))

In [None]:
# Rank-ordered spgc phylogenetic_i values for two intermediate-prevalence
# windows, to see how sensitive the distribution is to the window.
fig, ax = plt.subplots()
for _lo, _hi in [(0.1, 0.9), (0.2, 0.8)]:
    ax.plot(
        gene_stats[lambda x: (x.prevalence_spgc > _lo) & (x.prevalence_spgc < _hi)]
        .phylogenetic_i_spgc.sort_values()
        .values,
        label=f"{_lo} < prevalence_spgc < {_hi}",
    )
ax.set_xlabel("gene rank")
ax.set_ylabel("phylogenetic_i_spgc")
ax.legend();

In [None]:
# Median phylogenetic_i per COG category (ref and spgc), over genes of
# intermediate prevalence (strictly between 0.1 and 0.9 for the matching
# estimate) belonging to the species in `_focal_species_list`.
# NOTE(review): only species "102506" is included — looks like an exploratory
# restriction; confirm or widen this list.
_focal_species_list = ["102506"]

cog_category_median_phylogenetic_i_ref = {}
cog_category_median_phylogenetic_i_spgc = {}

d = gene_stats.join(gene_x_cog_category_matrix)

# Loop-invariant masks, computed once instead of once per category.
_in_focal_species = d.species.isin(_focal_species_list)
_mid_prevalence_ref = (d.prevalence_ref > 0.1) & (d.prevalence_ref < 0.9)
_mid_prevalence_spgc = (d.prevalence_spgc > 0.1) & (d.prevalence_spgc < 0.9)

for _col in cog_category_order:
    # Genes absent from the COG matrix are NaN after the join: not in category.
    _in_category = d[_col].fillna(False).astype(bool)
    cog_category_median_phylogenetic_i_ref[_col] = d.loc[
        _mid_prevalence_ref & _in_category & _in_focal_species, "phylogenetic_i_ref"
    ].median()
    cog_category_median_phylogenetic_i_spgc[_col] = d.loc[
        _mid_prevalence_spgc & _in_category & _in_focal_species, "phylogenetic_i_spgc"
    ].median()

cog_category_median_phylogenetic_i_ref = pd.Series(
    cog_category_median_phylogenetic_i_ref
)
cog_category_median_phylogenetic_i_spgc = pd.Series(
    cog_category_median_phylogenetic_i_spgc
)

In [None]:
# Per-COG-category median phylogenetic_i, spgc vs. ref side by side, ordered by
# the spgc median. Named columns replace the anonymous 0/1 of the transposed frame.
pd.DataFrame(
    {
        "spgc": cog_category_median_phylogenetic_i_spgc,
        "ref": cog_category_median_phylogenetic_i_ref,
    }
).sort_values("spgc")

In [None]:
# Distribution of the reference phylogenetic_i for genes with vs. without a COG
# category annotation. The intermediate-prevalence filter uses the reference
# prevalence so that it matches the plotted (ref) statistic; previously it
# filtered on prevalence_spgc while plotting phylogenetic_i_ref.
sns.violinplot(
    data=gene_stats.join(gene_x_cog_category_matrix)[
        lambda x: (x.prevalence_ref > 0.1) & (x.prevalence_ref < 0.9)
    ],
    x="no_category",
    y="phylogenetic_i_ref",
)