## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
# Run the notebook from the project root so the relative data paths used below
# resolve. Reads the PROJECT_ROOT environment variable (KeyError if it is unset).
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
# Show the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
# Global plot styling: seaborn "talk" context, figures rendered at 50 DPI.
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 50

In [None]:
# Matplotlib color name assigned to each genome type label.
genome_type_palette = {"SPGC": "tab:green", "MAG": "tab:orange", "Isolate": "tab:blue"}

## Data Setup

### Metadata

In [None]:
# Unique species IDs (as strings) belonging to the "xjin_ucfmt_hmp2" species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda tbl: tbl["species_group_id"] == "xjin_ucfmt_hmp2", "species_id"]
    .astype(str)
    .unique()
)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ";"-delimited lineage string into a Series keyed by rank prefix.

    The fields are labelled, in order, ``d__``, ``p__``, ``c__``, ``o__``,
    ``f__``, ``g__`` and ``s__``; each field's text is kept verbatim. pandas
    raises ``ValueError`` when the string does not contain exactly seven fields.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(data=fields, index=rank_prefixes)

In [None]:
# Taxonomy table: one row per species_id, one column per rank prefix.
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

species_taxonomy = (
    # Keep only rows where the genome is its own species representative.
    pd.read_table(species_taxonomy_inpath)[lambda x: x.Genome == x.Species_rep]
    # species_id is "1" followed by the third "-"-separated field of MGnify_accession.
    .assign(species_id=lambda x: "1" + x.MGnify_accession.str.split("-").str[2])
    .set_index("species_id")
    # Expand each Lineage string into one column per taxonomic rank.
    .Lineage.apply(parse_taxonomy_string)
)
species_taxonomy

### Strain Statistics

In [None]:
# Load the per-strain metadata table of every species into one DataFrame
# (`filt_stats`, indexed by genome_id, with an added `species` column).
# Species whose table does not exist on disk are recorded in `missing_species`.
missing_species = []

_species_list = species_list
# _species_list = ["100003"]

# Accumulate in a separate list so that `filt_stats` is only ever bound to the
# finished DataFrame; an interrupted loop no longer leaves a half-built list
# under the name that downstream cells expect to be a DataFrame.
_filt_stats_parts = []
for species in tqdm(_species_list):
    inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.eggnog-strain_gene.strain_meta_for_analysis.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    data = pd.read_table(inpath, index_col="genome_id").assign(species=species)
    _filt_stats_parts.append(data)
filt_stats = pd.concat(_filt_stats_parts)

print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

In [None]:
# Display the combined strain metadata table.
filt_stats

In [None]:
# 2D histogram of min_ref_gene_raw_diss (x) against min_ref_gene_filt_diss (y)
# for genomes that pass the filter and have a ref_nn_genome_id.
plt.hist2d(
    "min_ref_gene_raw_diss",
    "min_ref_gene_filt_diss",
    data=filt_stats[lambda x: x.passes_filter].dropna(subset=["ref_nn_genome_id"]),
    bins=50,
    # Power-law color scaling (exponent 1/3).
    norm=mpl.colors.PowerNorm(1 / 3),
)
# Trailing None keeps the hist2d return value out of the cell output.
None

In [None]:
# Species for which no strain metadata file was found by the loading loop.
missing_species

In [None]:
# Column names available in the strain metadata table.
filt_stats.columns

In [None]:
# What fraction of StrainFacts genotypes (with enough positions == 100) have a genotype dissimilarity of less than 1%?
filt_stats.loc[
    lambda df: df.genome_type.isin(["SPGC"])
    & df.passes_geno_positions
    & df.passes_in_sample_list,
    "min_ref_geno_diss",
].lt(0.01).mean()

In [None]:
# Set of species for dereplication analysis: every species with at least one
# SPGC genome that passes the genotype-position and sample-list checks.
d = filt_stats.loc[
    lambda df: df.genome_type.isin(["SPGC"])
    & df.passes_geno_positions
    & df.passes_in_sample_list
]
species_list1 = list(d["species"].unique())
len(species_list1)

In [None]:
# Dereplication analysis: how many clusters are solely represented by an SPGC (or MAG/Isolate)?

# Genomes with enough genotyped positions; SPGC genomes must additionally pass
# the sample-list check, Isolate/MAG genomes are kept unconditionally.
d = filt_stats.loc[
    lambda df: df.passes_geno_positions
    & (
        (df.genome_type.isin(["SPGC"]) & df.passes_in_sample_list)
        | df.genome_type.isin(["Isolate", "MAG"])
    )
]
d["genome_type"].value_counts()

In [None]:
d = (
    # Take all genomes with enough genotyped positions, and drop any SPGC
    # genomes that are only found in UCFMT or XJIN samples.
    filt_stats[
        lambda x: x.passes_geno_positions
        & (
            (x.genome_type.isin(["SPGC"]) & x.passes_in_sample_list)
            | (x.genome_type.isin(["Isolate", "MAG"]))
        )
    ]
    # Count the number of genomes of each type in each cluster.
    [["species", "clust", "genome_type"]]
    .value_counts()
    .unstack(fill_value=0)
    # Tag each cluster by its "best type": the first of Isolate, MAG, SPGC that
    # is present at all. idxmax on the presence mask returns the first True
    # column, so a cluster is tagged SPGC only when it holds no Isolate or MAG.
    # (idxmax on the raw counts would instead pick the most numerous type, so
    # e.g. 1 Isolate + 5 SPGC would be tagged SPGC.)
    .assign(
        best_genome_type=lambda x: x[["Isolate", "MAG", "SPGC"]]
        .gt(0)
        .idxmax(axis=1)
    )
    # Count for each species the number of clusters with each tag.
    .groupby("species")
    .best_genome_type.value_counts()
    .unstack(fill_value=0)
)

# Per-species cluster counts and fractions by tag, with taxonomy, for the 20
# species with the most SPGC-only clusters.
(
    d.join(d.divide(d.sum(axis=1), axis=0), rsuffix="_frac")
    .join(species_taxonomy[["p__", "f__", "g__", "s__"]])
    .sort_values("SPGC", ascending=False)
    .head(20)
)

In [None]:
d = (
    # Take all genomes that pass the full filter.
    filt_stats[lambda x: x.passes_filter]
    # Count the number of genomes of each type in each cluster.
    [["species", "clust", "genome_type"]]
    .value_counts()
    .unstack(fill_value=0)
    # Tag each cluster by its "best type": the first of Isolate, MAG, SPGC that
    # is present at all. idxmax on the presence mask returns the first True
    # column, so a cluster is tagged SPGC only when it holds no Isolate or MAG.
    # (idxmax on the raw counts would instead pick the most numerous type.)
    .assign(
        best_genome_type=lambda x: x[["Isolate", "MAG", "SPGC"]]
        .gt(0)
        .idxmax(axis=1)
    )
    # Count for each species the number of clusters with each tag.
    .groupby("species")
    .best_genome_type.value_counts()
    .unstack(fill_value=0)
)

# Per-species cluster counts and fractions by tag, with taxonomy, for the 40
# species with the most SPGC-only clusters.
(
    d.join(d.divide(d.sum(axis=1), axis=0), rsuffix="_frac")
    .join(species_taxonomy[["f__", "g__", "s__"]])
    .sort_values("SPGC", ascending=False)
    .head(40)
)

In [None]:
# Per species, count genomes that have enough genotyped positions but no
# min_ref_geno_diss value (printed for inspection before the checks below).
print(
    filt_stats[
        lambda x: x.min_ref_geno_diss.isna() & x.passes_geno_positions
    ].species.value_counts()
)

# Use `not`, not `~`: `~` is a bitwise invert, and on a plain Python bool
# `~True == -2` is truthy, so the assertion could never fail. `not` is correct
# whether `.any()` returns a Python bool or a NumPy bool.
assert not (
    filt_stats[lambda x: x.passes_geno_positions].min_ref_geno_diss.isna().any()
), "min_ref_geno_diss is missing for genomes that pass passes_geno_positions"
assert not (
    filt_stats[lambda x: x.passes_geno_positions].ref_nn_gene_raw_diss.isna().any()
), "ref_nn_gene_raw_diss is missing for genomes that pass passes_geno_positions"

In [None]:
# Species with at least 10 SPGC genomes that pass the full filter.
d = filt_stats.loc[lambda df: df.genome_type.isin(["SPGC"]) & df.passes_filter]
species_list2 = idxwhere(d["species"].value_counts() >= 10)

In [None]:
# The ten species with the largest fraction of genomes lacking a
# min_ref_gene_raw_diss value (ascending, so the worst species is last).
(
    filt_stats.groupby("species")["min_ref_gene_raw_diss"]
    .apply(lambda col: col.isna().mean())
    .sort_values()
    .tail(10)
)

In [None]:
# Quartiles of the four gene-content dissimilarity columns over filter-passing
# SPGC genomes that have a ref_nn_genome_id (one row per column after .T).
(
    filt_stats.loc[lambda df: df.passes_filter & df.genome_type.isin(["SPGC"])]
    .dropna(subset=["ref_nn_genome_id"])
    .loc[
        :,
        [
            "min_ref_gene_raw_diss",
            "min_ref_gene_filt_diss",
            "ref_nn_gene_raw_diss",
            "ref_nn_gene_filt_diss",
        ],
    ]
    .quantile([0.25, 0.5, 0.75])
    .T
)

### Relationship between genotype and gene distance

In [None]:
# Spearman correlation between genotype dissimilarity and each gene-content
# dissimilarity column, computed separately for SPGC genomes and for
# Isolate/MAG genomes.
# Note: the top-level `x` (a column name) is distinct from the `x` lambda
# parameters below, which shadow it inside each lambda body.
x = "min_ref_geno_diss"

# Genomes that pass the filter and have a ref_nn_genome_id.
d0 = filt_stats[lambda x: x.passes_filter].dropna(subset=["ref_nn_genome_id"])

for y in [
    "min_ref_gene_raw_diss",
    "min_ref_gene_filt_diss",
    "ref_nn_gene_raw_diss",
    "ref_nn_gene_filt_diss",
]:
    for genome_set in [["SPGC"], ["Isolate", "MAG"]]:
        d1 = d0[lambda x: x.genome_type.isin(genome_set)]
        print(y, genome_set, sp.stats.spearmanr(d1[x], d1[y]))

In [None]:
# Genomes that pass the filter and have a ref_nn_genome_id.
d0 = filt_stats[lambda x: x.passes_filter].dropna(subset=["ref_nn_genome_id"])

# OLS fit on Isolate/MAG genomes only: ref_nn_gene_raw_diss modelled from a
# sum-coded species term (no intercept) plus log10(min_ref_geno_diss).
# The commented-out formula terms are alternative model terms left disabled.
fit_raw_ref = smf.ols(
    (
        "ref_nn_gene_raw_diss ~ "
        "0 + C(species, Sum) + np.log10(min_ref_geno_diss)"
        # "+ C(species, Sum):genome_type "
        # "+ C(species, Sum):np.log10(min_ref_geno_diss) "
        # "+ genome_type:np.log10(min_ref_geno_diss)"
        # "+ C(species, Sum):genome_type:np.log10(min_ref_geno_diss)"
    ),
    data=d0[lambda x: x.genome_type.isin(["Isolate", "MAG"])],
).fit()
print(fit_raw_ref.aic)

In [None]:
# Observed vs. predicted ref_nn_gene_raw_diss, one 2D-histogram panel per
# genome type. Uses `d0` and `fit_raw_ref` from the model-fitting cell.
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(7, 9))

# zip stops at the shorter input, so panels beyond the number of genome types
# are left empty.
for (genome_type, d1), ax in zip(
    d0.assign(predict=lambda x: fit_raw_ref.predict(x)).groupby("genome_type"),
    axs.flatten(),
):
    ax.set_title(genome_type)
    bins = np.linspace(0, 0.5, num=50)
    ax.hist2d(
        "predict",
        "ref_nn_gene_raw_diss",
        data=d1,
        bins=bins,
        norm=mpl.colors.PowerNorm(1 / 2),
    )
    ax.set_aspect(1)
    # White reference line along observed == predicted.
    ax.plot([0, 0.6], [0, 0.6], c="w", lw=1)
fig.tight_layout()
# `ax` is whichever axis the loop used last, so only that panel is labelled.
ax.set_xlabel("predicted")
ax.set_ylabel("observed")

In [None]:
# Apply the fit trained on reference genomes (fit_raw_ref) to SPGC genomes, then
# print quartiles of the observed values, quartiles of the predictions, their
# Pearson correlation, and quartiles of the residuals (observed - predicted).
d1 = d0[lambda x: x.genome_type == "SPGC"].assign(
    predict=lambda x: fit_raw_ref.predict(x)
)
print(d1["ref_nn_gene_raw_diss"].quantile([0.25, 0.5, 0.75]))
print(d1["predict"].quantile([0.25, 0.5, 0.75]))
print(sp.stats.pearsonr(d1["ref_nn_gene_raw_diss"], d1["predict"]))
print((d1.ref_nn_gene_raw_diss - d1.predict).quantile([0.25, 0.5, 0.75]))
# Histogram of the residuals.
plt.hist(d1.ref_nn_gene_raw_diss - d1.predict, bins=50)

In [None]:
# Residual quartiles again, shown as rich cell output (same values as printed
# by the previous cell).
(d1.ref_nn_gene_raw_diss - d1.predict).quantile([0.25, 0.5, 0.75])

## Taxonomic diversity of strains

In [None]:
# Number of filter-passing SPGC strains per species, with taxonomy, restricted
# to species in phylum p__Euryarchaeota.
(
    filt_stats.loc[lambda df: df.passes_filter & df.genome_type.isin(["SPGC"])]
    .species.value_counts()
    .to_frame("num_spgc_strains")
    .join(species_taxonomy)
    .loc[lambda df: df.p__ == "p__Euryarchaeota"]
)

In [None]:
# Number of filter-passing SPGC strains per species, annotated with taxonomy.
d0 = (
    filt_stats.loc[lambda df: df.passes_filter & df.genome_type.isin(["SPGC"])]
    .species.value_counts()
    .to_frame("num_spgc_strains")
    .join(species_taxonomy)
)

fig, ax = plt.subplots(figsize=(5, 20))
ax.invert_yaxis()

# Phyla ordered by their total SPGC strain count, largest first.
_phylum_list = (
    d0.groupby("p__")["num_spgc_strains"].sum().sort_values(ascending=False).index
)
_phylum_palette = lib.plot.construct_ordered_palette(_phylum_list, cm="rainbow")

# One y position per species, stacked phylum by phylum; each legend entry is
# the tuple (phylum, number of species, total SPGC strains).
y_start = 0
for p__ in _phylum_list:
    d1 = d0.loc[lambda df: df.p__ == p__]
    num_species = d1.shape[0]
    yy = np.arange(num_species) + y_start
    ax.scatter(
        d1["num_spgc_strains"],
        yy,
        color=_phylum_palette[p__],
        s=50,
        marker="x",
        lw=2,
        label=(p__, num_species, d1["num_spgc_strains"].sum()),
    )
    y_start = y_start + num_species
ax.set_xscale("log")
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Number of filter-passing SPGC strains per species, annotated with taxonomy.
d0 = (
    filt_stats.loc[lambda df: df.passes_filter & df.genome_type.isin(["SPGC"])]
    .species.value_counts()
    .to_frame("num_spgc_strains")
    .join(species_taxonomy)
)

# Total SPGC strains per phylum, largest first.
d0.groupby("p__")["num_spgc_strains"].sum().sort_values(ascending=False)

In [None]:
# Per species: number of filter-passing SPGC strains and the median number of
# genotyped positions, joined with taxonomy.
d0 = (
    filt_stats[lambda x: x.passes_filter & x.genome_type.isin(["SPGC"])]
    .groupby("species")
    .apply(
        lambda x: pd.Series(
            dict(
                num_spgc_strains=len(x),
                num_geno_positions=x.num_geno_positions.median(),
            )
        )
    )
    # .to_frame("num_spgc_strains")
    .join(species_taxonomy)
)


# For each of four selected phyla, print the eight species with the most SPGC
# strains (strain count and species name only).
for p__ in [
    "p__Firmicutes_A",
    "p__Bacteroidota",
    "p__Proteobacteria",
    "p__Euryarchaeota",
]:
    d1 = d0[lambda x: x.p__ == p__].sort_values("num_spgc_strains", ascending=False)
    print(p__)
    print(d1[["num_spgc_strains", "s__"]].head(8))
    print()

## How many species/genomes for pangenomics?

In [None]:
cog_x_category1.index.is_unique

In [None]:
cog_x_category2 = pd.read_table('ref/cog-20.meta.tsv', encoding='latin1', names=['cog', 'category', 'description', 'preferred_name', 'pathway', '_6', '_7'], index_col='cog').category.apply(lambda x: ''.join(sorted(x)))


In [None]:
cog_x_category1 = gene_meta[lambda x: x.species == '102506'].assign(cog=lambda x: x.eggNOG_OGs.fillna('-').apply(lambda y: y.split('@')[0]))[lambda x: x.cog.str.startswith('COG')][['cog', 'COG_category']].drop_duplicates()#.set_index('cog').COG_category.apply(lambda x: ''.join(sorted(x)))
# pd.DataFrame(dict(a=cog_x_category1, b=cog_x_category2))

In [None]:
cog_x_category1[cog_x_category1.duplicated(subset=['cog'], keep=False)].sort_values('cog')

In [None]:
_species_list = species_list

gene_meta = []
missing_species = []
for species in tqdm(_species_list):
    inpath = f"data/species/sp-{species}/midasdb.gene_meta.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    data = pd.read_table(inpath, index_col="gene_id").assign(species=species)
    gene_meta.append(data)
gene_meta = pd.concat(gene_meta)

print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

In [None]:
# Load the gene -> COG category table of every species and concatenate them;
# species whose file is absent are recorded in `missing_species`.
_species_list = species_list

_result = []
missing_species = []
for species in tqdm(_species_list):
    inpath = f"data/species/sp-{species}/midasdb.gene_x_cog_category.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    data = pd.read_table(inpath, index_col="gene_id")
    _result.append(data)
gene_x_cog_category = pd.concat(_result)

print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

# Boolean presence matrix: one row per gene_id, one column per cog_category,
# True where the (gene, category) pair is listed and False elsewhere.
gene_x_cog_category_matrix = gene_x_cog_category.reset_index().set_index(['gene_id', 'cog_category']).assign(present=True).present.unstack('cog_category', fill_value=False)

In [None]:
# Description text for each COG category, indexed by cog_category (the file is
# read without a header row; the `color` column is not kept).
cog_category_description = pd.read_table('ref/cog-20.categories.tsv', names=['cog_category', 'color', 'description'], index_col='cog_category').description
cog_category_description.sort_index()

In [None]:
# Dereplication analysis: how many clusters are solely represented by an SPGC (or MAG/Isolate)?

# Genomes passing the full filter; SPGC genomes must additionally pass the
# sample-list check, Isolate/MAG genomes are kept unconditionally.
d = filt_stats[
    lambda x: x.passes_filter
    & (
        (x.genome_type.isin(["SPGC"]) & x.passes_in_sample_list)
        | (x.genome_type.isin(["Isolate", "MAG"]))
    )
]
print("Num strains of each type:", d.genome_type.value_counts())

# Species with at least 10 retained SPGC genomes. Despite the "gt10" in the
# name, the threshold is inclusive (>= 10), matching the printed message.
species_with_gt10_spgc_strains = idxwhere(
    d[lambda x: x.genome_type.isin(["SPGC"])].species.value_counts() >= 10
)
print("Num species with >=10 genomes:", len(species_with_gt10_spgc_strains))

print(
    "Num strains of each type in pangenomics species:",
    d[
        lambda x: x.species.isin(species_with_gt10_spgc_strains)
    ].genome_type.value_counts(),
)

In [None]:
# Quartiles (nearest observed value) of the per-species SPGC genome count among
# the species selected for pangenomics.
(
    d.loc[
        lambda df: df.species.isin(species_with_gt10_spgc_strains),
        ["species", "genome_type"],
    ]
    .value_counts()
    .unstack("genome_type", fill_value=0)
    .SPGC.quantile([0.25, 0.5, 0.75], interpolation="nearest")
)

In [None]:
# Per-species genome counts by type for the pangenomics species, plus the ratio
# of SPGC genomes to reference (Isolate + MAG) genomes, sorted by that ratio.
(
    d.loc[
        lambda df: df.species.isin(species_with_gt10_spgc_strains),
        ["species", "genome_type"],
    ]
    .value_counts()
    .unstack("genome_type", fill_value=0)
    .assign(spgc_ratio=lambda counts: counts.SPGC / (counts.Isolate + counts.MAG))
    .sort_values("spgc_ratio")
)

## Pangenomics (TODO: this doesn't really belong in this notebook)

In [None]:
# Gene prevalence per species from two sources: SPGC strains (`spgc`) and
# reference genomes (`ref`). Both files are read without a header row.
# Accumulate in a separate list so that `prevalence` is only ever bound to the
# finished DataFrame, even if the loop is interrupted part-way through.
_prevalence_parts = []

for species in tqdm(species_with_gt10_spgc_strains):
    spgc_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.prevalence-hmp2.tsv"
    ref_inpath = (
        f"data/species/sp-{species}/midasdb.gene75_new.uhgg-strain_gene.prevalence.tsv"
    )
    spgc_prev = pd.read_table(
        spgc_inpath, names=["gene_id", "prevalence"], index_col="gene_id"
    ).prevalence
    ref_prev = pd.read_table(
        ref_inpath, names=["gene_id", "prevalence"], index_col="gene_id"
    ).prevalence
    # Building the frame from two Series takes the union of their gene_ids;
    # genes present in only one source get NaN in the other column.
    data = pd.DataFrame(dict(spgc=spgc_prev, ref=ref_prev)).assign(species=species)
    _prevalence_parts.append(data)
# Genes missing from one source are given prevalence 0 in that source.
prevalence = pd.concat(_prevalence_parts).fillna(0)

In [None]:
# Pearson correlation and log-scaled 2D histogram of reference vs. SPGC gene
# prevalence across all loaded genes.
print(sp.stats.pearsonr(prevalence.ref, prevalence.spgc))
plt.hist2d("ref", "spgc", data=prevalence, bins=20, norm=mpl.colors.LogNorm())
plt.colorbar()
# Trailing None keeps the colorbar object out of the cell output.
None

In [None]:
# Histogram of SPGC prevalence over genes with non-zero SPGC prevalence.
plt.hist(prevalence.spgc[lambda x: x > 0], bins=np.linspace(0, 1, num=20))
None

In [None]:
# Classify each gene by its SPGC prevalence: > 0.9 is "core", > 0.15 is
# "shell", anything else is "cloud". np.where is applied once to the whole
# column instead of once per gene through Series.map; the labels, index and
# Series name are the same as before.
prevalence_class = pd.Series(
    np.where(
        prevalence.spgc > 0.9,
        "core",
        np.where(prevalence.spgc > 0.15, "shell", "cloud"),
    ),
    index=prevalence.index,
    name=prevalence.spgc.name,
)

In [None]:
# One row per gene: its prevalence class, its COG category flags (left join, so
# genes absent from the category matrix get NaN there), and one boolean
# indicator column per prevalence class.
d0 = (
    prevalence_class.to_frame("prevalence_class")
    .join(gene_x_cog_category_matrix)
    .assign(
        cloud=lambda x: x.prevalence_class == "cloud",
        shell=lambda x: x.prevalence_class == "shell",
        core=lambda x: x.prevalence_class == "core",
    )
)

# For every (prevalence class, COG category) pair, build the 2x2 contingency
# table and record a log2 odds ratio (computed with a +1 pseudocount in every
# cell) and the Fisher exact test p-value (computed on the raw counts).
# NOTE(review): p-values are not adjusted for multiple testing here — confirm
# whether that is intended.
result = []
for _prevalence_class, _cog_category in tqdm(list(product(['core', 'shell', 'cloud'], gene_x_cog_category_matrix.columns))):
    d1 = d0[[_prevalence_class, _cog_category]].value_counts().unstack().reindex(index=[True, False], columns=[True, False]).fillna(0)
    d1_pc = d1 + 1
    log_oddsratio = np.log2((d1_pc.loc[True, True] / d1_pc.loc[True, False]) / (d1_pc.loc[False, True] / d1_pc.loc[False, False]))
    result.append((_prevalence_class, _cog_category, log_oddsratio, sp.stats.fisher_exact(d1)[1]))
prevalence_class_cog_category_enrichment = pd.DataFrame(result, columns=['prevalence_class', 'cog_category', 'log2_oddsratio', 'pvalue']).set_index(['prevalence_class', 'cog_category'])

In [None]:
# Heatmap of log2 odds ratios (COG category x prevalence class), with a dot
# drawn in every cell whose p-value is below 0.05.
d = prevalence_class_cog_category_enrichment

d_oddsr = d.log2_oddsratio.unstack('prevalence_class')
d_signf = d.pvalue.map(lambda x: np.where(x < 0.05, '·', '')).unstack('prevalence_class')

prevalence_class_order = ['core', 'shell', 'cloud']
# Rows ordered by enrichment in the core class, most enriched first.
cog_category_order = d_oddsr['core'].sort_values(ascending=False).index

fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(d_oddsr.loc[cog_category_order, prevalence_class_order], norm=mpl.colors.PowerNorm(1, vmin=-4, vmax=+4), cmap='coolwarm', ax=ax)
# Annotations (because seaborn annotations are failing)
# Each marker is placed at the cell centre (column j, row i, offset by 0.5).
for (i, _cog_category), (j, _prevalence_class) in product(enumerate(cog_category_order), enumerate(prevalence_class_order)):
    ax.annotate(d_signf.loc[_cog_category, _prevalence_class], xy=(j + 0.5, i + 0.5), ha='center', va='center')

In [None]:
# COG categories ranked by enrichment among core genes, with descriptions.
(
    prevalence_class_cog_category_enrichment.loc["core"]
    .sort_values("log2_oddsratio", ascending=False)
    .join(cog_category_description)
)

In [None]:
# COG categories ranked by enrichment among shell genes, with descriptions.
(
    prevalence_class_cog_category_enrichment.loc["shell"]
    .sort_values("log2_oddsratio", ascending=False)
    .join(cog_category_description)
)

In [None]:
# COG categories ranked by enrichment among cloud genes, with descriptions.
(
    prevalence_class_cog_category_enrichment.loc["cloud"]
    .sort_values("log2_oddsratio", ascending=False)
    .join(cog_category_description)
)

In [None]:
# Per-genome prevalence-class tallies, loaded for SPGC strains and for
# reference genomes of every pangenomics species, then restricted to genomes
# that pass the filter.
spgc_prevalence_class_tally = []
ref_prevalence_class_tally = []

for species in tqdm(species_with_gt10_spgc_strains):
    spgc_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.prevalence_class_fraction-hmp2.tsv"
    # Index by (species, genome_id); the file's "strain" column holds the genome id.
    spgc_data = (
        pd.read_table(spgc_inpath)
        .assign(species=species)
        .rename(columns={"strain": "genome_id"})
        .astype({"genome_id": str})
        .set_index(["species", "genome_id"])
    )
    spgc_prevalence_class_tally.append(spgc_data)

    # Same table for the reference genomes of this species.
    ref_inpath = f"data/species/sp-{species}/midasdb.gene75_new.uhgg-strain_gene.prevalence_class_fraction.tsv"
    ref_data = (
        pd.read_table(ref_inpath)
        .assign(species=species)
        .rename(columns={"strain": "genome_id"})
        .astype({"genome_id": str})
        .set_index(["species", "genome_id"])
    )
    ref_prevalence_class_tally.append(ref_data)

# Filter SPGC
# The boolean `passes_filter` Series, re-indexed by (species, genome_id), is
# used as a row mask on the concatenated table.
spgc_prevalence_class_tally = (
    pd.concat(spgc_prevalence_class_tally)
    .fillna(0)
    .loc[
        filt_stats[
            lambda x: x.genome_type.isin(["SPGC"])
            & x.species.isin(species_with_gt10_spgc_strains)
        ]
        .reset_index()
        .set_index(["species", "genome_id"])
        .passes_filter
    ]
)

# Filter Ref
# Same masking, using the MAG and Isolate rows of filt_stats.
ref_prevalence_class_tally = (
    pd.concat(ref_prevalence_class_tally)
    .fillna(0)
    .loc[
        filt_stats[
            lambda x: x.genome_type.isin(["MAG", "Isolate"])
            & x.species.isin(species_with_gt10_spgc_strains)
        ]
        .reset_index()
        .set_index(["species", "genome_id"])
        .passes_filter
    ]
)

In [None]:
# Convert each genome's tallies into fractions by dividing every row by its
# own total, for SPGC strains and for reference genomes.
spgc_prevalence_class_frac = spgc_prevalence_class_tally.div(
    spgc_prevalence_class_tally.sum(axis=1), axis=0
)
ref_prevalence_class_frac = ref_prevalence_class_tally.div(
    ref_prevalence_class_tally.sum(axis=1), axis=0
)

In [None]:
# Per-species median pangenome fractions across SPGC strains.
d1 = spgc_prevalence_class_frac.groupby("species").median()

# Overlaid histograms of the per-species median for each prevalence class.
for frac in ["core", "shell", "cloud"]:
    plt.hist(d1[frac], label=frac, alpha=0.7)
plt.legend()

# Quartiles across species. `d1` already holds one median row per species, so
# it is summarised directly instead of being grouped by species a second time.
d1.quantile([0.25, 0.5, 0.75])

In [None]:
# Per-species median pangenome fractions across reference genomes.
d1 = ref_prevalence_class_frac.groupby("species").median()

# Overlaid histograms of the per-species median for each prevalence class.
for frac in ["core", "shell", "cloud"]:
    plt.hist(d1[frac], label=frac, alpha=0.7)
plt.legend()

# Quartiles across species. `d1` already holds one median row per species, so
# it is summarised directly instead of being grouped by species a second time.
d1.quantile([0.25, 0.5, 0.75])

In [None]:
# Long table of per-species median fractions, indexed by (species,
# pangenome_fraction), with one column for reference genomes and one for SPGC.
d0 = pd.DataFrame(
    dict(
        ref=ref_prevalence_class_frac.groupby("species").median().stack(),
        spgc=spgc_prevalence_class_frac.groupby("species").median().stack(),
    )
).rename_axis(index=["species", "pangenome_fraction"])
# One scatter series per pangenome fraction: reference (x) vs. SPGC (y).
for pangenome_fraction, d1 in d0.groupby("pangenome_fraction"):
    plt.scatter("ref", "spgc", data=d1, s=10, alpha=0.7)

In [None]:
# The 20 species with the largest median SPGC cloud fraction, alongside the
# reference medians ("_ref" columns), reference genome counts and taxonomy.
(
    spgc_prevalence_class_frac.groupby("species")
    .median()
    .join(ref_prevalence_class_frac.groupby("species").median(), rsuffix="_ref")
    .assign(
        total_num_ref_genomes=filt_stats.loc[
            lambda df: df.genome_type.isin(["MAG", "Isolate"]), "species"
        ].value_counts()
    )
    .join(species_taxonomy[["f__", "g__", "s__"]])
    .sort_values("cloud", ascending=False)
    .head(20)
)

In [None]:
# The 20 species with the smallest median SPGC cloud fraction, alongside the
# reference medians ("_ref" columns), reference genome counts and taxonomy.
(
    spgc_prevalence_class_frac.groupby("species")
    .median()
    .join(ref_prevalence_class_frac.groupby("species").median(), rsuffix="_ref")
    .assign(
        total_num_ref_genomes=filt_stats.loc[
            lambda df: df.genome_type.isin(["MAG", "Isolate"]), "species"
        ].value_counts()
    )
    .join(species_taxonomy[["f__", "g__", "s__"]])
    .sort_values("cloud", ascending=True)
    .head(20)
)

In [None]:
# Long table: one row per (species, pangenome_fraction) holding the per-species
# median SPGC fraction, joined with that species' taxonomy.
d2 = (
    spgc_prevalence_class_frac.groupby("species")
    .median()
    .rename_axis(columns="pangenome_fraction")
    .stack()
    .to_frame("frac")
    .join(species_taxonomy, on="species")
    .reset_index()
)

# Strip plot of the fractions, grouped by pangenome fraction and split by phylum.
sns.stripplot(x="pangenome_fraction", hue="p__", y="frac", data=d2, dodge=True)
lib.plot.rotate_xticklabels()
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Per-species median SPGC and reference ("_ref") fractions, with genome counts
# of each kind and taxonomy.
d = (
    spgc_prevalence_class_frac.groupby("species")
    .median()
    .join(ref_prevalence_class_frac.groupby("species").median(), rsuffix="_ref")
    .assign(
        total_num_ref_genomes=filt_stats[
            lambda x: x.genome_type.isin(["MAG", "Isolate"])
        ]["species"].value_counts(),
        total_num_spgc_genomes=filt_stats[lambda x: x.genome_type.isin(["SPGC"])][
            "species"
        ].value_counts(),
    )
    .join(species_taxonomy[["p__", "f__", "g__", "s__"]])
)

_phylum_palette = lib.plot.construct_ordered_palette(d.p__.unique(), cm="tab10")


# mpltern is not referenced by name below; presumably it is imported for the
# side effect of registering the "ternary" projection — TODO confirm.
import mpltern

fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(projection="ternary", ternary_sum=100.0)
ax.grid()


ax.set_tlabel("Core (%)")
ax.set_llabel("Shell (%)")
ax.set_rlabel("Cloud (%)")

# Restrict each ternary axis to the plotted range.
ax.set_tlim(20, 100)
ax.set_llim(10, 62)
ax.set_rlim(0, 35)

for p__, d1 in d.groupby("p__"):
    # One open circle per species, sized by its number of SPGC genomes.
    ax.scatter(
        "core",
        "shell",
        "cloud",
        data=d1,
        s=d1["total_num_spgc_genomes"],
        marker="o",
        lw=2,
        edgecolor=_phylum_palette[p__],
        facecolor="none",
        alpha=0.85,
    )
    # Empty scatter that only contributes a legend entry for this phylum.
    ax.scatter(
        [],
        [],
        [],
        edgecolor=_phylum_palette[p__],
        label=p__,
        lw=2,
        facecolor="none",
    )

ax.legend(loc="upper left")

In [None]:
# phylum_palette = lib.plot.construct_ordered_palette(species_taxonomy.p__.unique(), cm='tab10')

# Per-species median SPGC and reference ("_ref") fractions, with reference
# genome counts and taxonomy.
d = (
    spgc_prevalence_class_frac.groupby("species")
    .median()
    .join(ref_prevalence_class_frac.groupby("species").median(), rsuffix="_ref")
    .assign(
        total_num_ref_genomes=filt_stats[
            lambda x: x.genome_type.isin(["MAG", "Isolate"])
        ]["species"].value_counts()
    )
    .join(species_taxonomy[["f__", "g__", "s__"]])
)

# mpltern is not referenced by name below; presumably it is imported for the
# side effect of registering the "ternary" projection — TODO confirm.
import mpltern

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(projection="ternary", ternary_sum=100.0)

ax.set_tlabel("Core (%)")
ax.set_llabel("Shell (%)")
ax.set_rlabel("Cloud (%)")

# ax.grid()

# Ternary scatter of the reference-genome medians, one point per species.
ax.scatter("core_ref", "shell_ref", "cloud_ref", data=d, s=10)

In [None]:
# phylum_palette = lib.plot.construct_ordered_palette(species_taxonomy.p__.unique(), cm='tab10')

# Per-species median SPGC fractions joined with taxonomy.
d = (
    spgc_prevalence_class_frac.groupby("species")
    .median()
    .rename_axis(columns="pangenome_fraction")
    # .stack()
    # .to_frame("frac")
    .join(species_taxonomy, on="species")
    # .assign(phylum_color=lambda x: x.p__.map(phylum_palette))
    # .reset_index()
)

# mpltern is not referenced by name below; presumably it is imported for the
# side effect of registering the "ternary" projection — TODO confirm.
import mpltern

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(projection="ternary", ternary_sum=100.0)

ax.set_tlabel("Core (%)")
ax.set_llabel("Shell (%)")
ax.set_rlabel("Cloud (%)")

# ax.grid()

# Ternary scatter of the SPGC medians, one point per species.
ax.scatter("core", "shell", "cloud", data=d, s=10)

In [None]:
_species_list = species_list

# Collect the per-species Moran's I tables; species without an output file are
# recorded in `missing_species` instead of raising.
_morans_i_tables = []
missing_species = []
for species in tqdm(_species_list):
    inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.morans_i.tsv"
    if os.path.exists(inpath):
        data = pd.read_table(inpath, index_col="gene_id")
        _morans_i_tables.append(data)
    else:
        missing_species.append(species)

# Single gene-indexed table covering every species that had a file.
morans_i = pd.concat(_morans_i_tables)

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
# Debug aid: shows the last path built by the loading loop in the previous cell
# (the loop variable leaks into the notebook namespace).
inpath

In [None]:
# SPGC genomes that pass the quality filter.
d = filt_stats[lambda x: x.genome_type.isin(["SPGC"]) & x.passes_filter]
# Species with at least 20 such genomes (presumably `idxwhere` returns the index
# labels where the condition holds — see lib.pandas_util).
species_list3 = idxwhere((d.species.value_counts() >= 20))
len(species_list3)

In [None]:
# Moran's I per gene, restricted to genes whose species is in `species_list2`.
d = morans_i.join(gene_meta.species).loc[lambda x: x.species.isin(species_list2)]
# Rows with no missing values, shared by the histogram and the correlation.
_d_complete = d.dropna()

# 2D histogram of reference vs. SPGC values (power-law color scale), followed
# by their Pearson correlation as the cell output.
plt.hist2d("ref", "spgc", data=_d_complete, bins=20, norm=mpl.colors.PowerNorm(1 / 3))
sp.stats.pearsonr(_d_complete["ref"], _d_complete["spgc"])

In [None]:
# Moran's I per gene, restricted to genes whose species is in `species_list3`.
d = morans_i.join(gene_meta.species).loc[lambda x: x.species.isin(species_list3)]
# Rows with no missing values, shared by the histogram and the correlation.
_d_complete = d.dropna()

# 2D histogram of reference vs. SPGC values (power-law color scale), followed
# by their Pearson correlation as the cell output.
plt.hist2d("ref", "spgc", data=_d_complete, bins=20, norm=mpl.colors.PowerNorm(1 / 5))
sp.stats.pearsonr(_d_complete["ref"], _d_complete["spgc"])

In [None]:
# Overlaid histograms of the SPGC and reference columns of `d` (as left by the
# previous cell), on a log-scaled count axis.
for _column in ("spgc", "ref"):
    plt.hist(d[_column].dropna(), bins=100, alpha=0.5)

plt.yscale("log")
None

In [None]:
# Gene metadata joined with the gene-by-COG-category matrix, plus the reference
# Moran's I (`cmi`, missing values replaced by 0) and reference prevalence.
d = gene_meta.join(gene_x_cog_category_matrix).assign(cmi=morans_i.ref, prevalence=prevalence.ref).fillna({'cmi': 0})
d

In [None]:
# NOTE: Can take 15 or more minutes to run.

# Per-gene table of SPGC Moran's I (`cmi`), SPGC prevalence, species, and the
# gene-by-COG-category columns; genes with any missing value are dropped.
d0 = (
    pd.DataFrame(dict(cmi=morans_i.spgc, prevalence=prevalence.spgc))
    .assign(species=gene_meta.species)
    .join(gene_x_cog_category_matrix)
    .dropna()
)

# The prevalence window is identical for every fit below, so apply it once
# rather than re-filtering the whole gene table inside the loops.
_d0_window = d0[lambda x: (x.prevalence > 0.15) & (x.prevalence < 0.9)]

results = {}

# Per-species fits: robust linear model of cmi on a 4-df natural cubic spline
# of prevalence plus the category indicator.
for species in tqdm(species_list3):
    d1 = _d0_window[_d0_window.species == species]
    for category in gene_x_cog_category_matrix.columns:
        # A category that no gene (or every gene) in this species belongs to
        # has no contrast to estimate.
        _category_frac = d1[category].mean()
        if (_category_frac == 0) or (_category_frac == 1):
            continue
        fit = smf.rlm(f'cmi ~ cr(prevalence, 4) + {category}', data=d1).fit()
        _term = f'{category}[T.True]'
        results[(species, category)] = [fit.params[_term], fit.pvalues[_term]]

# Pooled fits across all selected species, with species as a covariate.
# The data subset does not depend on the category, so it is built once.
d1 = _d0_window[_d0_window.species.isin(species_list3)]
for category in tqdm(gene_x_cog_category_matrix.columns):
    fit = smf.rlm(f'cmi ~ cr(prevalence, 4) + {category} + species', data=d1).fit()
    _term = f'{category}[T.True]'
    results[('overall', category)] = [fit.params[_term], fit.pvalues[_term]]

# Rows are indexed by (species or 'overall', category).
spgc_species_by_cog_category_cdi_bias_test = pd.DataFrame(results.values(), index=results.keys(), columns=['coef', 'pvalue'])

In [None]:
# NOTE: Can take 15 or more minutes to run.

# Per-gene table of reference Moran's I (`cmi`), reference prevalence, species,
# and the gene-by-COG-category columns; genes with any missing value are dropped.
d0 = (
    pd.DataFrame(dict(cmi=morans_i.ref, prevalence=prevalence.ref))
    .assign(species=gene_meta.species)
    .join(gene_x_cog_category_matrix)
    .dropna()
)

# The prevalence window is identical for every fit below, so apply it once
# rather than re-filtering the whole gene table inside the loops.
_d0_window = d0[lambda x: (x.prevalence > 0.15) & (x.prevalence < 0.9)]

results = {}

# Per-species fits: robust linear model of cmi on a 4-df natural cubic spline
# of prevalence plus the category indicator.
for species in tqdm(species_list3):
    d1 = _d0_window[_d0_window.species == species]
    for category in gene_x_cog_category_matrix.columns:
        # A category that no gene (or every gene) in this species belongs to
        # has no contrast to estimate.
        _category_frac = d1[category].mean()
        if (_category_frac == 0) or (_category_frac == 1):
            continue
        fit = smf.rlm(f'cmi ~ cr(prevalence, 4) + {category}', data=d1).fit()
        _term = f'{category}[T.True]'
        results[(species, category)] = [fit.params[_term], fit.pvalues[_term]]

# Pooled fits across all selected species, with species as a covariate.
# The data subset does not depend on the category, so it is built once.
d1 = _d0_window[_d0_window.species.isin(species_list3)]
for category in tqdm(gene_x_cog_category_matrix.columns):
    fit = smf.rlm(f'cmi ~ cr(prevalence, 4) + {category} + species', data=d1).fit()
    _term = f'{category}[T.True]'
    results[('overall', category)] = [fit.params[_term], fit.pvalues[_term]]

# Rows are indexed by (species or 'overall', category).
ref_species_by_cog_category_cdi_bias_test = pd.DataFrame(results.values(), index=results.keys(), columns=['coef', 'pvalue'])

In [None]:
d0 = spgc_species_by_cog_category_cdi_bias_test
# Species-by-category matrices; combinations that were never fit get a
# coefficient of 0 and a p-value of 1.
d1_coef = d0.coef.unstack(fill_value=0)
# Use a plain conditional so the cells hold real Python strings. `np.where` on
# a scalar returns a 0-d array, which may be why seaborn's `annot=` was failing
# here — TODO confirm.
d1_signf = d0.pvalue.unstack(fill_value=1.0).map(lambda p: '•' if p < 0.05 else '')

# Columns ordered by mean coefficient; rows by mean squared coefficient
# (largest first), with the pooled 'overall' row placed last.
# These two orders are reused by the reference heatmap further down.
cmi_category_order = list(d1_coef.mean().sort_values().index)
cmi_species_order = list(np.square(d1_coef).mean(1).sort_values(ascending=False).drop('overall').index) + ['overall']

d2_coef = d1_coef.loc[cmi_species_order, cmi_category_order]
d2_signf = d1_signf.loc[cmi_species_order, cmi_category_order]

fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(d2_coef, norm=mpl.colors.SymLogNorm(linthresh=0.001, vmin=-.1, vmax=.1), xticklabels=1, yticklabels=1, cmap='PuOr', ax=ax)

# Manual significance markers, centered in each heatmap cell. Non-significant
# cells are skipped instead of adding an empty text artist for each of them.
for (i, species), (j, category) in product(enumerate(cmi_species_order), enumerate(cmi_category_order)):
    _signf_marker = d2_signf.loc[species, category]
    if _signf_marker:
        ax.annotate(_signf_marker, xy=(j + 0.5, i + 0.5), ha='center', va='center')

In [None]:
# Pooled ('overall') coefficient and p-value per category, in heatmap column order.
spgc_species_by_cog_category_cdi_bias_test.loc['overall'].reindex(cmi_category_order)

In [None]:
# Category descriptions listed in the same order as the heatmap columns.
cog_category_description.reindex(cmi_category_order)

In [None]:
# Toy example: Pearson correlation between two binary vectors of length n that
# differ only in `offset` positions at the 0/1 boundary.
n = 20
r = 0.1
offset = 1

# Number of zeros (p) and ones (q) in the first vector.
p, q = (int(round(n * _frac)) for _frac in (r, 1 - r))
print(p, q)

x = p * [0] + q * [1]
# Same vector with `offset` of the zeros flipped to ones.
y = (p - offset) * [0] + (q + offset) * [1]
print(x)
print(y)
sp.stats.pearsonr(x, y)

In [None]:
# Comma-separated species ids, convenient for pasting elsewhere.
','.join(species_list3)

In [None]:
d0 = ref_species_by_cog_category_cdi_bias_test
# Species-by-category matrices; combinations that were never fit get a
# coefficient of 0 and a p-value of 1.
d1_coef = d0.coef.unstack(fill_value=0)
# Plain conditional so the cells hold real Python strings (`np.where` on a
# scalar returns a 0-d array).
d1_signf = d0.pvalue.unstack(fill_value=1.0).map(lambda p: '•' if p < 0.05 else '')

# Reuse the row/column order computed for the SPGC heatmap so the two figures
# line up. `reindex` (rather than `.loc`) tolerates species or categories that
# appear in the SPGC results but not in the reference results; they are shown
# with the same neutral values `unstack` uses for missing fits.
d2_coef = d1_coef.reindex(index=cmi_species_order, columns=cmi_category_order, fill_value=0)
d2_signf = d1_signf.reindex(index=cmi_species_order, columns=cmi_category_order, fill_value='')

fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(d2_coef, norm=mpl.colors.SymLogNorm(linthresh=0.001, vmin=-.1, vmax=.1), xticklabels=1, yticklabels=1, cmap='PuOr', ax=ax)

# Manual significance markers, centered in each heatmap cell. Non-significant
# cells are skipped instead of adding an empty text artist for each of them.
for (i, species), (j, category) in product(enumerate(cmi_species_order), enumerate(cmi_category_order)):
    _signf_marker = d2_signf.loc[species, category]
    if _signf_marker:
        ax.annotate(_signf_marker, xy=(j + 0.5, i + 0.5), ha='center', va='center')

In [None]:
_species_list = species_list


def _read_gene_clust(path):
    """Read a headerless two-column (gene_id, cluster) table as a gene-indexed Series."""
    return pd.read_table(path, names=['gene_id', 'cluster'], index_col="gene_id").cluster


# Pair each gene's reference cluster with its SPGC cluster, species by species.
# NOTE(review): only the SPGC file is checked for existence; a missing
# reference file would raise in `_read_gene_clust`.
_co_clust_tables = []
missing_species = []
for species in tqdm(_species_list):
    ref_inpath = f"data/species/sp-{species}/midasdb.gene75_new.uhgg-strain_gene.gene_clust-t10.tsv"
    spgc_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.gene_clust-t10.tsv"
    if os.path.exists(spgc_inpath):
        ref_data = _read_gene_clust(ref_inpath)
        spgc_data = _read_gene_clust(spgc_inpath)
        _co_clust_tables.append(
            pd.DataFrame(dict(ref_clust=ref_data, spgc_clust=spgc_data)).assign(species=species)
        )
    else:
        missing_species.append(species)

# Genes present in only one of the two tables get the sentinel cluster id -5,
# which lets both columns be stored as integers.
_clust_columns = ['spgc_clust', 'ref_clust']
co_clust = (
    pd.concat(_co_clust_tables)
    .fillna(dict.fromkeys(_clust_columns, -5))
    .astype(dict.fromkeys(_clust_columns, int))
)

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
# Number of genes per (species, SPGC cluster), for non-negative cluster ids in
# the species of `species_list3`, keeping only clusters with more than one gene.
d = (co_clust[lambda x: (x.spgc_clust >= 0) & x.species.isin(species_list3)][['species', 'spgc_clust']].value_counts()[lambda x: (x > 1)])
clust_list = d.index
# (number of clusters, total genes in them, mean genes per cluster)
len(d), d.sum(), d.mean()

In [None]:
# Median number of multi-gene clusters per species.
d.groupby('species').apply(len).median()

In [None]:
# One row per (gene, KEGG module): the comma-separated KEGG_Module annotation is
# split and exploded, so a gene id can appear several times in the index.
gene_x_module = gene_meta.KEGG_Module.dropna().str.split(',').explode()#[lambda x: x.str.startswith('map')]
# Most frequent modules.
gene_x_module.value_counts().head(20)

In [None]:
# One row per (gene, pathway): the comma-separated annotation is split and
# exploded, keeping only entries whose id starts with 'map'.
gene_x_pathway = gene_meta.KEGG_inpathway.dropna().str.split(',').explode()[lambda x: x.str.startswith('map')]
# Most frequent pathways.
gene_x_pathway.value_counts().head(20)

In [None]:
# Genes with a non-negative SPGC cluster id and at least one KEGG module, with
# one row per (gene, module) after the join. Cluster ids are prefixed with the
# species id so that they are unique across species.
d0 = (
    co_clust[lambda x: (x.spgc_clust >= 0) & (x.index.isin(gene_x_module.index))]
    .join(gene_x_module)
    .assign(
        # ref_clust=lambda x: x.species + "-" + x.ref_clust.astype(str),
        spgc_clust=lambda x: x.species + "-" + x.spgc_clust.astype(str),
    )
)
d0

In [None]:
thresh = 1

# Seeded generator so the permutation null distribution is reproducible.
_perm_rng = np.random.default_rng(0)

obs_count_module_multihit = (
    # Count the clusters in which some KEGG_Module is hit by more than `thresh` rows.
    d0[["spgc_clust", "KEGG_Module"]]
    .value_counts()
    .gt(thresh)
    .groupby('spgc_clust')
    .any()
    .sum()
)
perm_count_module_multihit = []
for i in tqdm(range(999)):
    perm_count_module_multihit.append(
        d0.assign(
            # Permute cluster labels within species. `transform` returns its
            # result in the original row order, so every row receives a label
            # drawn from its own species regardless of how `d0` is sorted.
            # (`groupby(...).sample(frac=1).values` returns rows group by
            # group, which is only aligned if `d0` is already sorted that way.)
            spgc_clust=lambda x: x.groupby("species")
            .spgc_clust.transform(lambda s: _perm_rng.permutation(s.to_numpy()))
            .to_numpy()
        )[["spgc_clust", "KEGG_Module"]]
        .value_counts()
        .gt(thresh)
        .groupby('spgc_clust')
        .any()
        .sum()
    )
perm_count_module_multihit = np.array(perm_count_module_multihit)

# Null distribution with the observed count marked.
plt.hist(perm_count_module_multihit)
plt.axvline(obs_count_module_multihit)

# Null mean, null std, observed count, observed / null-mean ratio.
print(perm_count_module_multihit.mean(), perm_count_module_multihit.std(), obs_count_module_multihit, obs_count_module_multihit / perm_count_module_multihit.mean())

In [None]:
thresh = 2

# Seeded generator so the permutation null distribution is reproducible.
_perm_rng = np.random.default_rng(0)

obs_count_module_multihit = (
    # Count the clusters in which some KEGG_Module is hit by more than `thresh` rows.
    d0[["spgc_clust", "KEGG_Module"]]
    .value_counts()
    .gt(thresh)
    .groupby('spgc_clust')
    .any()
    .sum()
)
perm_count_module_multihit = []
for i in tqdm(range(999)):
    perm_count_module_multihit.append(
        d0.assign(
            # Permute cluster labels within species. `transform` returns its
            # result in the original row order, so every row receives a label
            # drawn from its own species regardless of how `d0` is sorted.
            # (`groupby(...).sample(frac=1).values` returns rows group by
            # group, which is only aligned if `d0` is already sorted that way.)
            spgc_clust=lambda x: x.groupby("species")
            .spgc_clust.transform(lambda s: _perm_rng.permutation(s.to_numpy()))
            .to_numpy()
        )[["spgc_clust", "KEGG_Module"]]
        .value_counts()
        .gt(thresh)
        .groupby('spgc_clust')
        .any()
        .sum()
    )
perm_count_module_multihit = np.array(perm_count_module_multihit)

# Null distribution with the observed count marked.
plt.hist(perm_count_module_multihit)
plt.axvline(obs_count_module_multihit)

# Null mean, null std, observed count, observed / null-mean ratio.
print(perm_count_module_multihit.mean(), perm_count_module_multihit.std(), obs_count_module_multihit, obs_count_module_multihit / perm_count_module_multihit.mean())

In [None]:
thresh = 3

# Seeded generator so the permutation null distribution is reproducible.
_perm_rng = np.random.default_rng(0)

obs_count_module_multihit = (
    # Count the clusters in which some KEGG_Module is hit by more than `thresh` rows.
    d0[["spgc_clust", "KEGG_Module"]]
    .value_counts()
    .gt(thresh)
    .groupby('spgc_clust')
    .any()
    .sum()
)
perm_count_module_multihit = []
for i in tqdm(range(999)):
    perm_count_module_multihit.append(
        d0.assign(
            # Permute cluster labels within species. `transform` returns its
            # result in the original row order, so every row receives a label
            # drawn from its own species regardless of how `d0` is sorted.
            # (`groupby(...).sample(frac=1).values` returns rows group by
            # group, which is only aligned if `d0` is already sorted that way.)
            spgc_clust=lambda x: x.groupby("species")
            .spgc_clust.transform(lambda s: _perm_rng.permutation(s.to_numpy()))
            .to_numpy()
        )[["spgc_clust", "KEGG_Module"]]
        .value_counts()
        .gt(thresh)
        .groupby('spgc_clust')
        .any()
        .sum()
    )
perm_count_module_multihit = np.array(perm_count_module_multihit)

# Null distribution with the observed count marked.
plt.hist(perm_count_module_multihit)
plt.axvline(obs_count_module_multihit)

# Null mean, null std, observed count, observed / null-mean ratio.
print(perm_count_module_multihit.mean(), perm_count_module_multihit.std(), obs_count_module_multihit, obs_count_module_multihit / perm_count_module_multihit.mean())

In [None]:
_species_list = species_list

# Collect the per-species donor tables; species without an output file are
# recorded in `missing_species` instead of raising.
_donor_tables = []
missing_species = []
for species in tqdm(_species_list):
    inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene-ucfmt.tsv"
    if os.path.exists(inpath):
        _donor_tables.append(pd.read_table(inpath, index_col="gene_id"))
    else:
        missing_species.append(species)

# 2 * D0097 + 3 * D0044 gives a distinct code for each presence combination,
# which is then translated into a readable label.
_donor_label_by_code = {0: "neither", 5: "both", 2: "d97", 3: "d44"}
donor_comparison = (
    pd.concat(_donor_tables)
    .dropna(subset=['D0097', 'D0044'])
    .assign(label=lambda x: (x.D0097 * 2 + x.D0044 * 3).map(_donor_label_by_code))
)

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
# Count of genes per donor label within each species (species x label table).
d = donor_comparison.groupby(gene_meta.species).label.value_counts().unstack()
# Convert each species' counts to fractions, then summarize every label by its
# quartiles across species.
d.apply(lambda x: x / x.sum(), axis=1).quantile([0.25, 0.5, 0.75]).T

In [None]:
_species_list = species_list

# Collect the per-species MWAS tables (tagged with their species id); species
# without an output file are recorded in `missing_species` instead of raising.
_mwas_tables = []
missing_species = []
for species in tqdm(_species_list):
    inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.hmp2_mwas-f30-n1.tsv"
    if os.path.exists(inpath):
        _mwas_tables.append(pd.read_table(inpath, index_col="gene_id").assign(species=species))
    else:
        missing_species.append(species)

# Derived columns: log-scale effect size and p-value, plus subject counts
# summed over the IBD / nonIBD groups and the resulting gene prevalence.
mwas = pd.concat(_mwas_tables).assign(
    log2_oddsratio_pc_ibd=lambda x: np.log2(x.oddsratio_pc_ibd),
    neg_log10_pvalue=lambda x: -np.log10(x.fisher_exact_pvalue_ibd),
    num_subjects_present=lambda x: x["present-nonIBD"] + x["present-IBD"],
    num_subjects_absent=lambda x: x["absent-nonIBD"] + x["absent-IBD"],
    num_subjects_total=lambda x: x.num_subjects_present + x.num_subjects_absent,
    prevalence=lambda x: x.num_subjects_present / x.num_subjects_total,
)

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

In [None]:
# Taxonomy lookup for a single species id.
species_taxonomy.loc['102492']

In [None]:
# Spot check: rows 50-99 of the genes whose eggNOG_OGs annotation is missing or
# does not start with 'COG', alongside their COG_category.
gene_meta[['eggNOG_OGs', 'COG_category']][lambda x: ~x.eggNOG_OGs.fillna('').str.startswith('COG')].head(100).tail(50)

In [None]:
# MWAS results for shell genes that are present in exactly one of the two
# donors, with at least 10 subjects on each side of the presence/absence
# split. FDR is computed over this subset, and rows are sorted by p-value.
d = (
    mwas.assign(donor=donor_comparison.label, prevalence_class=prevalence_class)
    .loc[
        lambda x: (x.num_subjects_present >= 10)
        & (x.num_subjects_absent >= 10)
        & x.donor.isin(["d97", "d44"])
        & (x.prevalence_class == 'shell')
    ]
    .assign(fdr=lambda x: fdrcorrection(x.fisher_exact_pvalue_ibd)[1])
    .sort_values("fisher_exact_pvalue_ibd")
)

# Effect size vs. p-value (log scale, smallest p-values at the top), colored by FDR.
plt.scatter(
    'log2_oddsratio_pc_ibd',
    'fisher_exact_pvalue_ibd',
    c='fdr',
    data=d,
    lw=1,
    facecolor='none',
    marker='o',
    cmap='viridis_r',
)
plt.yscale('log')
plt.gca().invert_yaxis()
plt.colorbar()

In [None]:
# Columns to show from the MWAS table and from the gene annotations.
_mwas_columns = [
    "fisher_exact_pvalue_ibd",
    "log2_oddsratio_pc_ibd",
    "num_subjects_total",
    "prevalence",
    "donor",
]
_gene_meta_columns = [
    "species",
    "PFAMs",
    "eggNOG_OGs",
    "COG_category",
    "Description",
    "Preferred_name",
]

# Hits with p < 0.01, annotated with gene metadata.
d.loc[d.fisher_exact_pvalue_ibd < 1e-2, _mwas_columns].join(gene_meta[_gene_meta_columns])

In [None]:
# NOTE(review): `qqplot` is imported but not used in this cell; the QQ-style
# plot below is drawn by hand.
from statsmodels.graphics.gofplots import qqplot

# Same filtered MWAS subset as the scatter cell above: shell genes present in
# exactly one donor, >= 10 subjects present and absent, FDR over the subset,
# sorted by ascending p-value.
d0 = mwas.assign(donor=donor_comparison.label).assign(prevalence_class=prevalence_class)[
    lambda x: (x.num_subjects_present >= 10)
    & (x.num_subjects_absent >= 10)
    & x.donor.isin(["d97", "d44"])
    & (x.prevalence_class == 'shell')
].assign(fdr=lambda x: fdrcorrection(x.fisher_exact_pvalue_ibd)[1]).sort_values(
    "fisher_exact_pvalue_ibd"
)

# Add the gene-by-COG-category columns.
d1 = d0.join(gene_x_cog_category_matrix)

fig, ax = plt.subplots(figsize=(10, 10))

# The ten categories with the most genes in the tested subset.
cog_category_list = gene_x_cog_category_matrix.loc[d0.index].sum().sort_values(ascending=False).head(10).index

for cat in cog_category_list:
    # Genes in this category (the column is used directly as a boolean mask).
    d2 = d1[d1[cat]]
    # x: evenly spaced quantiles k/n for k = 1..n; y: the p-values themselves
    # (the standard uniform's ppf is the identity on [0, 1]). This pairs
    # expected with observed quantiles only if `d2` is still in ascending
    # p-value order — presumably inherited from the sort above; verify.
    ax.scatter(np.linspace(0, 1, num=d2.shape[0] + 1)[1:], sp.stats.uniform.ppf(d2.fisher_exact_pvalue_ibd), label=cat, s=20)
# Reference diagonal: observed quantiles equal to uniform expectation.
ax.plot([0, 1], [0, 1], color='k')
ax.legend(bbox_to_anchor=(1, 1), markerscale=5)
ax.set_yscale('log')
ax.set_xscale('log')
# plt.xlim(right=1e-1)
# plt.ylim(top=1e-1)

In [None]:
# Genes flagged in column `G` (presumably COG category G — confirm) with
# p < 0.01, annotated with gene metadata and sorted by odds ratio.
d1[lambda x: (x.G) & (x.fisher_exact_pvalue_ibd < 1e-2)].oddsratio_pc_ibd.to_frame().join(gene_meta).sort_values('oddsratio_pc_ibd')

In [None]:
# Volcano-style 2D histogram over all MWAS rows: log2 odds ratio vs.
# -log10 p-value, with a power-law color scale. Trailing `None` suppresses the
# cell output.
plt.hist2d('log2_oddsratio_pc_ibd', 'neg_log10_pvalue', data=mwas, norm=mpl.colors.PowerNorm(1/10), bins=100)
None

In [None]:
# MWAS rows with more than 20 subjects on each side of the presence/absence
# split, sorted by ascending p-value.
d = mwas[
    lambda x: (x.num_subjects_present > 20)
    & (x.num_subjects_absent > 20)
].sort_values("fisher_exact_pvalue_ibd")

# NOTE(review): the original expression here was left unfinished and did not
# parse (`d.assign(lambda x: x..value_counts()`), which stops "Run All" at this
# cell. It is reconstructed from the commented-out hint that followed it
# (`fdrcorrection(d.fisher_exact_pvalue_ibd)[1]`) as: Benjamini-Hochberg FDR
# over the filtered tests, then the number of tests below / above FDR 0.05.
# Confirm that this matches the intended analysis.
d.assign(fdr=lambda x: fdrcorrection(x.fisher_exact_pvalue_ibd)[1]).fdr.lt(0.05).value_counts()

In [None]:
# Number of tests with FDR < 0.05 (column True) and >= 0.05 (column False) per
# species, top 20 species by significant count, annotated with taxonomy.
# NOTE(review): `fdr_by_species` is not defined anywhere in the cells shown
# around this one — confirm that it is created earlier in the notebook.
fdr_by_species.assign(signif=lambda x: x.fdr < 0.05)[['species', 'signif']].value_counts().unstack(fill_value=0).sort_values(True, ascending=False).join(species_taxonomy).head(20)

In [None]:
# NOTE(review): the original assignment did not parse (unbalanced parentheses
# and an empty `.pipe()`). It is reconstructed as a Benjamini-Hochberg
# correction applied separately within each species to the tests selected by
# `test_filter`; rows outside the filter get NaN through index alignment.
# This assumes gene ids are unique in the index of `mwas` — confirm, and
# confirm that per-species FDR is the intended analysis.
mwas = mwas.assign(
    fdr=lambda x: x[test_filter]
    .groupby('species')
    .fisher_exact_pvalue_ibd.transform(lambda p: fdrcorrection(p)[1])
)
# Volcano-style 2D histogram of the tested genes only.
plt.hist2d('log2_oddsratio_pc_ibd', 'neg_log10_pvalue', data=mwas[test_filter], norm=mpl.colors.PowerNorm(1/10), bins=100)
None

In [None]:
# Log-spaced p-value bins from 1e-5 to 1.
bins = np.logspace(-5, 0, num=40)
_tested = mwas[test_filter]

# Observed p-value histogram of the tested genes.
plt.hist(_tested.fisher_exact_pvalue_ibd, bins=bins)
# Count expected in each bin if p-values were uniform: bin width times the
# number of tests, drawn at each bin's right edge.
plt.plot(bins[1:], np.diff(bins) * _tested.shape[0])
plt.xscale("log")
plt.yscale("log")

In [None]:
plt.plot(np.sort(fdrcorrection(mwas[test_filter].fisher_exact_pvalue_ibd)[1]))