## Preamble

### Template Utils

In [None]:
%load_ext autoreload

In [None]:
# Run the notebook from the project root; the input paths below are relative.
# Raises KeyError if the PROJECT_ROOT environment variable is not set.
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
# Display the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
import warnings
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels as sm
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import AgglomerativeClustering
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

### Set Style

In [None]:
# Use seaborn's "talk" plotting context and render figures at 50 DPI.
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 50

### Papermill parameters

In [None]:
# This cell is tagged "parameters" for papermill.
# See <https://papermill.readthedocs.io/en/latest/usage-parameterize.html#how-parameters-work> for some gotchas.
# NOTE: *ALL* parameters should be passed to papermill. Values set here are only for prototyping.
species_id = "102506"

# Plain string literal: there is nothing to interpolate, so no f-prefix.
species_taxonomy_inpath = "ref/gtpro/species_taxonomy_ext.tsv"
# Prototype input paths, all derived from the default `species_id` above.
# Strain inference outputs (sample-to-strain assignment and the sfacts fit).
sample_to_spgc_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.spgc_ss-all.strain_samples.tsv"
sfacts_fit_inpath = (
    f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
)
ref_geno_inpath = f"data/species/sp-{species_id}/gtpro_ref.mgtp.nc"
# SPGC strain metadata and reference/inferred gene content tables.
spgc_meta_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95_ss-all_t-30_thresh-corr100-depth250.strain_meta.tsv"
ref_gene_copy_number_uhgg_inpath = (
    f"data/species/sp-{species_id}/gene75_new.reference_copy_number.nc"
)
spgc_gene_uhgg_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95_ss-all_t-30_thresh-corr100-depth250.strain_gene.tsv"
spgc_gene_uhgg_depth_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95_ss-all_t-30.strain_depth_ratio.tsv"
spgc_gene_uhgg_corr_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95_ss-all_t-30.strain_correlation.tsv"
# Gene annotation tables.
# NOTE(review): this path uses "pangenome." while the two annotation paths below
# use "pangenome_new." — confirm the mismatch is intentional.
uhgg_x_eggnog_inpath = (
    f"data/species/sp-{species_id}/pangenome.centroids.emapper.gene_x_eggnog.tsv"
)
uhgg_x_top_eggnog_inpath = (
    f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.gene_x_top_eggnog.tsv"
)
uhgg_gene_length_inpath = f"ref/midasdb_uhgg_new/pangenomes/{species_id}/cluster_info.txt"
gene_annotations_inpath = f"data/species/sp-{species_id}/pangenome_new.centroids.emapper.d/proteins.emapper.annotations"
# Gene-level and species-level depth.
uhgg_depth_inpath = (
    f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gene99_new-v22-agg75.depth2.nc"
)
species_depth_inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv"

# Study metadata tables. Plain string literals: nothing to interpolate.
mgen_inpath = "meta/ucfmt/mgen.tsv"
sample_inpath = "meta/ucfmt/sample.tsv"
subject_inpath = "meta/ucfmt/subject.tsv"

# Output path parameter; nothing in the visible cells writes to it, so it is
# presumably consumed by the pipeline that renders this notebook — TODO confirm.
html_outpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95_ss-all_t-30_thresh-corr100-depth250.donor_comparison.html"

## Data Loading / Validation

In [None]:
# Record when this run started in the notebook output.
!date

### Taxonomy

In [None]:
# Load the species taxonomy table and display the row for this species.
# `.loc` raises KeyError if `species_id` is missing from the table.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(species_taxonomy_inpath)
species_taxonomy.loc[species_id]

### Gene Annotations

In [None]:
# Tab-separated gene-by-eggNOG table (schema not shown in this notebook).
uhgg_x_eggnog = pd.read_table(uhgg_x_eggnog_inpath)

In [None]:
# Tab-separated gene-by-top-eggNOG table (schema not shown in this notebook).
uhgg_x_top_eggnog = pd.read_table(uhgg_x_top_eggnog_inpath)

In [None]:
# Mean of `centroid_99_length` within each `centroid_75` group of the
# cluster-info table, giving one length per `centroid_75` gene.
uhgg_gene_length = (
    pd.read_table(uhgg_gene_length_inpath)
    .groupby("centroid_75")["centroid_99_length"]
    .mean()
)

In [None]:
# Column names for the annotations file, which is read with explicit `names`
# (lines starting with "#" are skipped as comments).
# NOTE(review): "KEGG_inpathway" looks like a search-and-replace casualty of
# eggNOG-mapper's "KEGG_Pathway" column; check for downstream uses before renaming.
eggnog_column_names = "query seed_ortholog evalue score eggNOG_OGs max_annot_lvl COG_category Description Preferred_name GOs EC KEGG_ko KEGG_inpathway KEGG_Module KEGG_Reaction KEGG_rclass BRITE KEGG_TC CAZy BiGG_Reaction PFAMs".split(
    " "
)
# Annotations indexed by gene ID, with the "-" placeholder converted to NaN.
_gene_annotations = (
    pd.read_table(
        gene_annotations_inpath,
        comment="#",
        names=eggnog_column_names,
        index_col="query",
    )
    .rename_axis(index="gene_id")
    .replace({"-": np.nan})
)
# Left-join onto the gene-length table: every gene with a length keeps a row,
# whether or not it has an annotation.
_gene_annotations = uhgg_gene_length.to_frame().join(_gene_annotations)

_gene_annotations.info()

In [None]:
# Long-format gene -> COG category letter, taken from the eggNOG `COG_category`
# column: each multi-letter string is split into one row per letter, and the
# "-" placeholder used for unannotated genes is dropped again afterwards.
gene_x_cog_category1 = (
    _gene_annotations["COG_category"]
    .fillna("-")
    .apply(list)
    .explode()
    .loc[lambda letters: letters != "-"]
)
gene_x_cog_category1

In [None]:
# COG ID -> category string, read from a COG metadata table with explicit column names.
# NOTE(review): this path is hard-coded rather than passed as a papermill parameter.
cog_x_category = pd.read_table(
    "ref/cog-20.meta.tsv",
    names=["cog", "cog_category", "description", "short_name", "_4", "_5", "_6"],
    index_col="cog",
).cog_category
cog_x_category

In [None]:
# Inspect the raw eggNOG orthologous-group strings that are parsed in the next cell.
_gene_annotations.eggNOG_OGs

In [None]:
# Gene -> COG ID: split each comma-separated eggNOG_OGs string, keep entries
# that start with "COG", and strip everything from the "@" onward.
gene_x_cog = (
    _gene_annotations.eggNOG_OGs.fillna("")
    .str.split(",")
    .explode()[lambda x: x.str.startswith("COG")]
    .str.split("@")
    .str[0]
)
# A bare expression in the middle of a cell is evaluated and then discarded,
# so print the sanity-check tally explicitly.
print(gene_x_cog.value_counts().head())
# Map COG IDs to category strings and split those into one row per letter.
gene_x_cog_category2 = gene_x_cog.map(cog_x_category).dropna().apply(list).explode()
gene_x_cog_category2

In [None]:
# Union of the two gene -> COG category mappings (eggNOG's own COG_category
# column and the lookup through the COG metadata table), de-duplicated and
# returned as a Series indexed by centroid_75.
# FIXME: Which metadata table do I want?
gene_x_cog_category = (
    pd.concat([gene_x_cog_category1, gene_x_cog_category2])
    .reset_index()
    .drop_duplicates()
    .set_axis(["centroid_75", "cog_category"], axis=1)
    .set_index("centroid_75")["cog_category"]
)
# Number of distinct (gene, category) pairs.
gene_x_cog_category.shape[0]

In [None]:
# Replace eggNOG's COG_category with the merged mapping: for each gene, the
# sorted category letters are joined into one string; genes with no category
# get "".
# Aggregating the Series by its index level avoids `DataFrameGroupBy.apply`,
# which is deprecated when the applied frame includes the grouping column.
gene_annotations = _gene_annotations.assign(
    COG_category=(
        gene_x_cog_category.sort_values()
        .groupby(level=0)
        .agg("".join)
    )
).assign(COG_category=lambda d: d.COG_category.fillna(""))

In [None]:
# Boolean gene-by-category indicator matrix covering every gene in
# `gene_annotations`, plus a `no_category` column that is True for genes with
# no category at all.
gene_x_cog_category_matrix = (
    gene_x_cog_category.reset_index()
    .assign(tally=True)
    .set_index(["centroid_75", "cog_category"])
    .tally.unstack("cog_category")
    .fillna(False)
    .reindex(gene_annotations.index, fill_value=False)
    .assign(no_category=lambda x: x.sum(1) == 0)
)
gene_x_cog_category_matrix

### Raw Data

In [None]:
# Per-sample species depth, read with explicit column names (the file is
# treated as headerless) and indexed by sample.
species_depth = pd.read_table(
    species_depth_inpath, names=["sample", "depth"], index_col=["sample"]
).depth
species_depth

In [None]:
# Gene depth array, loaded fully into memory (it has a `sample` coordinate, used below).
uhgg_depth = xr.load_dataarray(uhgg_depth_inpath)

In [None]:
# Gene depth divided by the species depth of the same sample.
# `.sel` raises KeyError if any sample in `uhgg_depth` has no species depth.
uhgg_depth_ratio = uhgg_depth / species_depth.to_xarray().sel(sample=uhgg_depth.sample)

### SPGC Strains

In [None]:
# Sample -> SPGC strain assignment, with strain IDs cast to str.
sample_to_spgc = pd.read_table(sample_to_spgc_inpath, index_col="sample").strain.astype(
    str
)

# Strain colour palette; `other` is a light-grey RGBA, presumably the colour
# for unlisted keys — see lib.plot.construct_ordered_palette to confirm.
spgc_palette = lib.plot.construct_ordered_palette(
    sample_to_spgc, other=(0.8, 0.8, 0.8, 1.0)
)

In [None]:
# Load the fitted sfacts world and drop low-abundance strains (argument 0.5;
# the exact threshold semantics live in sfacts — TODO confirm).
sfacts_fit = sf.World.load(sfacts_fit_inpath).drop_low_abundance_strains(0.5)
# Strains should be str not int.
sfacts_fit.data["strain"] = sfacts_fit.strain.values.astype(str)
print(dict(sfacts_fit.sizes))

# Sum the metagenotype over the samples assigned to each SPGC strain, relabel
# the resulting `strain` dimension as `sample`, and estimate one genotype per
# strain with a pseudocount of 0.
spgc_est_geno = sf.Metagenotype(
    sfacts_fit.metagenotype.data.sel(sample=sample_to_spgc.index)
    .groupby(sample_to_spgc.to_xarray())
    .sum()
    .rename(strain="sample")
).to_estimated_genotype(pseudo=0)

# Pre-calculate shared heatmap decorations
# NOTE(review): no seed is set in this cell; if `random_sample` draws from
# unseeded RNG state, the position subsample differs between runs — confirm.
position_ss = sfacts_fit.random_sample(
    position=min(500, sfacts_fit.sizes["position"])
).position
w = sfacts_fit.sel(position=position_ss)
sample_linkage = w.unifrac_linkage()
# A ValueError from the position linkage is printed and the linkage left as None.
try:
    position_linkage = spgc_est_geno.linkage("position")
except ValueError as err:
    print(err)
    position_linkage = None
sample_colors = w.sample.to_series().map(sample_to_spgc).map(spgc_palette)
spgc_linkage = w.genotype.linkage()
spgc_colors = w.strain.to_series().map(spgc_palette)

In [None]:
# Stack the metagenotype-derived SPGC genotypes ("mgen") with the sfacts-fit
# genotypes ("fit") along the strain dimension and plot them together, using
# the subsampled positions and a linkage computed on all positions.
g = sf.data.Genotype.concat(
    dict(mgen=spgc_est_geno, fit=sfacts_fit.genotype), dim="strain"
)
g_pdist = g.pdist()
g_linkage = g.linkage()
sf.plot.plot_genotype(
    g.sel(position=position_ss), transpose=True, col_linkage=g_linkage
)

### Ref Strains

In [None]:
# "Reference GT-Pro genotype"
ref_geno = sf.Metagenotype.load(ref_geno_inpath).to_estimated_genotype()
ref_geno.data["strain"] = ref_geno.strain.to_series().map(
    lambda s: "UHGG" + s[len("GUT_GENOME") :]
)
ref_geno.shape

In [None]:
# Metadata for this species' reference genomes: filter the genome table by
# MGnify accession (built from `species_id` minus its first character), apply
# the same GUT_GENOME -> UHGG relabelling as `ref_geno`, and order rows to
# match `ref_geno.strain` (`.loc` raises KeyError if a strain is missing).
# NOTE(review): the table path is hard-coded rather than a papermill parameter.
reference_meta = (
    pd.read_table("ref/uhgg_genomes_all_4644.tsv", index_col="Genome")
    .rename_axis(index="genome_id")[
        lambda x: x.MGnify_accession == "MGYG-HGUT-" + species_id[1:]
    ]
    .rename(lambda s: "UHGG" + s[len("GUT_GENOME") :])
    .loc[ref_geno.strain]
)

In [None]:
# Distribution of reference-genome completeness and contamination on a shared
# 0-100 axis. Labels and a legend tell the two overlaid histograms apart.
bins = np.linspace(0, 100, num=101)
plt.hist(reference_meta.Completeness, bins=bins, label="Completeness")
plt.hist(reference_meta.Contamination, bins=bins, label="Contamination")
plt.legend()
# Suppress the repr of the last expression.
None

In [None]:
# NOTE: Select any ref genotype that is within the top-10 closest distances from an SPGC strain.
# Only references with Completeness > 97 and Contamination < 2 are candidates,
# and distances are computed on the positions present in `spgc_est_geno`.
spgc_to_ref_geno_cdist = ref_geno.sel(
    position=spgc_est_geno.position,
    strain=idxwhere(
        (reference_meta.Completeness > 97) & (reference_meta.Contamination < 2)
    ),
).cdist(spgc_est_geno)
# For each column of the distance table, take the 10 smallest-distance row
# labels; keep the unique union across columns.
ref_list = list(
    spgc_to_ref_geno_cdist.apply(lambda x: x.sort_values().head(10).index)
    .stack()
    .unique()
)

In [None]:
# SPGC strain metadata, with index labels converted to str.
spgc_meta = pd.read_table(spgc_meta_inpath, index_col="strain").rename(str)
print(spgc_meta.shape)
spgc_meta

In [None]:
# Reference gene copy numbers, binarised to 0/1 presence (copy number > 0) and
# transposed; later cells select reference genomes as columns of the result.
ref_gene_copy_number_uhgg = xr.load_dataarray(ref_gene_copy_number_uhgg_inpath)
ref_gene_uhgg = (ref_gene_copy_number_uhgg > 0).astype(int).to_pandas().T

In [None]:
# Sanity check: dimensions of the reference presence/absence table.
ref_gene_uhgg.shape

In [None]:
# Inferred gene content of SPGC strains: rows are gene IDs, columns are strains.
spgc_gene_uhgg = pd.read_table(spgc_gene_uhgg_inpath, index_col="gene_id").rename_axis(
    columns="strain"
)

In [None]:
# Column sums of the gene tables: one total per reference genome / SPGC strain.
ref_num_genes_uhgg = ref_gene_uhgg.sum()
spgc_num_genes_uhgg = spgc_gene_uhgg.sum()

In [None]:
# Per-gene prevalence across the selected reference genomes.
ref_gene_uhgg_prevalence = ref_gene_uhgg[ref_list].mean(1)
# Per-gene weight -p * log2(p); p == 0 yields NaN, which is replaced by 0.
# NOTE(review): this is only the "present" term of the binary entropy (there is
# no -(1 - p) * log2(1 - p) term) — confirm that is the intended weighting.
ref_gene_uhgg_entropy = (
    -ref_gene_uhgg_prevalence * np.log2(ref_gene_uhgg_prevalence)
).fillna(0)
plt.hist(ref_gene_uhgg_entropy)

### Strain Selection / Filtering

In [None]:
# Gene counts of SPGC strains with species_gene_frac > 0.9 (x) and gene counts
# of the reference genomes (y).
x = spgc_meta[lambda x: x.species_gene_frac > 0.9].num_genes
y = ref_num_genes_uhgg
# Fit a Student-t distribution to the SPGC gene counts with the `df` shape
# parameter fixed at 2.
_df, _loc, _scale = sp.stats.t.fit(x.values, fix_df=2)
_dist0 = sp.stats.t(_df, _loc, _scale)
# Normal distribution that reuses the location and scale from the t fit.
_dist1 = sp.stats.norm(_loc, _scale)

# Gene-count thresholds: the 99.9% and 0.1% quantiles of that normal.
thresh_max_num_uhgg_genes = _dist1.ppf(0.999)
thresh_min_num_uhgg_genes = _dist1.ppf(0.001)


bins = np.linspace(0, x.max() * 1.5, num=50)
xx = np.linspace(0, x.max() * 1.5, num=1000)

# Overlay both gene-count histograms with the fitted densities and thresholds.
plt.hist(x, density=True, bins=bins, alpha=0.2)
plt.hist(y, density=True, bins=bins, alpha=0.2)

plt.plot(xx, _dist0.pdf(xx), color="k")
plt.plot(xx, _dist1.pdf(xx), color="k", linestyle="--")
plt.axvline(thresh_max_num_uhgg_genes, lw=1, linestyle="--", color="k")
plt.axvline(thresh_min_num_uhgg_genes, lw=1, linestyle="--", color="k")

In [None]:
# Display the gene-count thresholds derived above.
thresh_min_num_uhgg_genes, thresh_max_num_uhgg_genes

In [None]:
# SPGC strains by total depth (log x) and species-gene fraction (logit y),
# coloured by gene count standardised with the fitted location and scale.
# Dashed lines mark the cut-offs used by the strain filter below (depth 1,
# fraction 0.9).
plt.scatter(
    "sum_depth",
    "species_gene_frac",
    c=(spgc_meta.num_genes - _loc) / _scale,
    data=spgc_meta,
    norm=mpl.colors.PowerNorm(1 / 1, vmin=-4, vmax=4),
)
plt.xscale("log")
plt.yscale("logit")
plt.axvline(1, lw=1, color="k", linestyle="--")
plt.axhline(0.9, lw=1, color="k", linestyle="--")
plt.colorbar()

In [None]:
# NOTE: Select SPGC strains that pass various filters
spgc_list = idxwhere(
    (spgc_meta.sum_depth > 1)
    & (spgc_meta.species_gene_frac > 0.9)
    & (spgc_num_genes_uhgg <= thresh_max_num_uhgg_genes)
    & (spgc_num_genes_uhgg >= thresh_min_num_uhgg_genes)
)

print(len(ref_list), len(spgc_list))

In [None]:
# Stop here if no SPGC strain passed the filters; the rest of the notebook needs at least one.
assert len(spgc_list) >= 1

In [None]:
# Combined genotype table: selected references (restricted to the SPGC
# positions) followed by selected SPGC strains, with missing values filled
# with 0.5.
strain_geno = sf.Genotype.concat(
    dict(
        ref=ref_geno.sel(strain=ref_list, position=spgc_est_geno.position),
        spgc=spgc_est_geno.sel(strain=spgc_list),
    ),
    dim="strain",
    rename=False,
).mlift("fillna", 0.5)

# Pairwise strain distances and linkage, both computed with q=1.
strain_geno_pdist = strain_geno.pdist(q=1)
strain_geno_linkage = strain_geno.linkage(pdist_kwargs=dict(q=1), optimal_ordering=True)

In [None]:
# Genome type per strain in `strain_geno`: taken from the reference metadata,
# then overwritten with "SPGC" for the inferred strains.
genome_type = reference_meta.Genome_type.reindex(strain_geno.strain)
genome_type.loc[genome_type.index.isin(spgc_list)] = "SPGC"
genome_type_order = ["Isolate", "MAG", "SPGC"]
genome_type_palette = lib.plot.construct_ordered_palette(genome_type_order)

# Standalone legend: the empty scatters exist only to create legend handles.
fig, ax = plt.subplots(figsize=(2, 1))
for _genome_type in genome_type_order:
    ax.scatter([], [], color=genome_type_palette[_genome_type], label=_genome_type)
ax.legend()

In [None]:
# Genotype heatmap over the subsampled positions, with strains ordered by the
# genotype linkage and annotated by genome type.
_colors = genome_type.map(genome_type_palette)
sf.plot.plot_genotype(
    strain_geno.sel(position=position_ss),
    transpose=True,
    col_linkage=strain_geno_linkage,
    col_colors=_colors,
)

In [None]:
# Clustered heatmap of pairwise genotype distances; rows and columns share the
# same linkage and genome-type colours.
_colors = genome_type.map(genome_type_palette)
sns.clustermap(
    strain_geno_pdist,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_geno_linkage,
    col_linkage=strain_geno_linkage,
    # figsize=(40, 40),
)

In [None]:
# Distribution of pairwise genotype distances within and between the reference
# and SPGC strain sets. `squareform` (imported at the top of the notebook)
# condenses a symmetric block so each pair is counted once; the ref-by-SPGC
# block is rectangular and is simply flattened.
bins = np.linspace(0, 1)
plt.hist(
    squareform(strain_geno_pdist.loc[ref_list, ref_list]),
    bins=bins,
    histtype="step",
    density=True,
    label="ref-to-ref",
)
plt.hist(
    strain_geno_pdist.loc[ref_list, spgc_list].values.flatten(),
    bins=bins,
    histtype="step",
    density=True,
    label="spgc-to-ref",
)
plt.hist(
    squareform(strain_geno_pdist.loc[spgc_list, spgc_list]),
    bins=bins,
    histtype="step",
    density=True,
    label="spgc-to-spgc",
)
plt.legend()

In [None]:
bins = np.linspace(0, 1, num=200)

# Put 1 on the diagonal so a strain's zero self-distance never wins the minimum.
_pdist = strain_geno_pdist + np.eye(len(strain_geno_pdist))

# (label, row strains, column strains). `.min()` reduces over rows, so each
# column strain contributes its distance to the nearest row strain.
_comparisons = [
    ("ref-to-ref", ref_list, ref_list),
    ("spgc-to-ref", ref_list, spgc_list),
    ("ref-to-spgc", spgc_list, ref_list),
    ("spgc-to-spgc", spgc_list, spgc_list),
]
for _label, _row_strains, _col_strains in _comparisons:
    plt.hist(
        _pdist.loc[_row_strains, _col_strains].min(),
        bins=bins,
        histtype="step",
        cumulative=True,
        density=True,
        label=_label,
    )

plt.legend()
# plt.axvline(0.43, lw=1, linestyle='--', color='k')
plt.xlabel("minimum distance")
plt.ylabel("cumulative fraction")

In [None]:
# Sanity check: dimensions of the SPGC and reference gene tables before combining them.
spgc_gene_uhgg.shape, ref_gene_uhgg.shape

In [None]:
# Gene-by-strain table for the selected references and SPGC strains. Genes
# missing from one source are filled with 0, and genes found in none of the
# selected strains are dropped.
strain_gene_uhgg = pd.concat(
    [ref_gene_uhgg[ref_list], spgc_gene_uhgg[spgc_list]], axis=1
).fillna(0)[lambda x: x.sum(1) > 0]
strain_gene_uhgg.shape

In [None]:
# Distribution of per-strain gene counts, reference genomes vs. SPGC strains.
bins = np.linspace(0, 10000)
for _label, _strains in [("ref", ref_list), ("spgc", spgc_list)]:
    plt.hist(
        strain_gene_uhgg[_strains].sum(),
        bins=bins,
        histtype="step",
        label=_label,
        density=True,
    )
plt.legend()

# Suppress the repr of the last expression.
None

## Donor Comparison

In [None]:
# Study metadata: metagenomes, samples and subjects, each indexed by its own ID.
mgen = pd.read_table(mgen_inpath, index_col="mgen_id")
sample = pd.read_table(sample_inpath, index_col="sample_id")
subject = pd.read_table(subject_inpath, index_col="subject_id")

# One row per metagenome, with its sample and subject columns joined on.
mgen_meta = mgen.join(sample, on="sample_id").join(subject, on="subject_id")

# Disabled per-subject depth aggregations, kept for reference.
# subject_uhgg_depth = mgen_meta[['subject_id']].join(uhgg_depth.to_pandas()).groupby('subject_id').mean().dropna()
# subject_mean_species_depth = mgen_meta[['subject_id']].join(species_depth).groupby('subject_id').depth.mean().dropna()

# total_subject_uhgg_depth = mgen_meta[['subject_id']].join(uhgg_depth.to_pandas()).groupby('subject_id').sum().dropna()
# total_subject_species_depth = mgen_meta[['subject_id']].join(species_depth).groupby('subject_id').depth.sum().dropna()
# subject_uhgg_depth_ratio = total_subject_uhgg_depth.divide(total_subject_species_depth, axis=0)

In [None]:
# List the distinct sample types present in the sample table.
sample.sample_type.unique()

In [None]:
# List the distinct donor IDs referenced by the metagenome metadata.
mgen_meta.donor_subject_id.unique()

In [None]:
# Fixed display orders for sample types and donors.
sample_type_order = ["donor", "baseline", "maintenance", "followup", "post_antibiotic"]
donor_order = ["D0044", "D0097", "D0485", "D0065"]

# Subject colours are assigned after sorting subjects by their donor ID.
subject_palette = lib.plot.construct_ordered_palette(
    subject.sort_values("donor_subject_id").index, cm="rainbow"
)
sample_type_palette = lib.plot.construct_ordered_palette(sample_type_order)

In [None]:
# Standalone legend for donor colours: the empty scatters exist only to create legend handles.
fig, ax = plt.subplots(figsize=(2, 2))
for s in donor_order:
    ax.scatter([], [], c=subject_palette[s], label=s)
ax.legend()

In [None]:
# Samples present in both the sfacts fit and the metagenome metadata.
# Sorted because iteration order of a set of strings changes between
# interpreter runs, which would otherwise make the sample order (and anything
# derived from it) irreproducible.
mgen_list = sorted(set(sfacts_fit.sample.values) & set(mgen_meta.index))

# Per-sample colour annotations: donor vs. non-donor sample, subject, the
# subject's donor, and the assigned SPGC strain (where there is one).
_col_colors = pd.DataFrame(
    dict(
        is_donor=(mgen_meta.sample_type == "donor").map(
            {True: "black", False: "darkgrey"}
        ),
        subject=mgen_meta.loc[mgen_list].subject_id.map(subject_palette),
        donor=mgen_meta.loc[mgen_list].donor_subject_id.map(subject_palette),
        spgc=sample_to_spgc.reindex(mgen_list).dropna().map(spgc_palette),
    )
)

_col_linkage = sfacts_fit.sel(sample=mgen_list).unifrac_linkage()

# Community composition, then metagenotype, for the same samples and ordering.
sf.plot.plot_community(
    sfacts_fit.sel(sample=mgen_list).drop_low_abundance_strains(0.05),
    col_colors=_col_colors,
    col_linkage=_col_linkage,
    row_linkage_func=lambda w: w.genotype.linkage(),
)
sf.plot.plot_metagenotype(
    sfacts_fit.sel(sample=mgen_list, position=position_ss).drop_low_abundance_strains(0.05),
    col_colors=_col_colors,
    col_linkage=_col_linkage,
)

In [None]:
# Most abundant strains in donor samples:

# Strain-by-subject table of mean community fraction across each subject's samples.
x = sfacts_fit.community.to_pandas().groupby(mgen_meta.subject_id).mean().T
print(x["D0044"].sort_values(ascending=False).head(3))
print(x["D0097"].sort_values(ascending=False).head(3))
# Metadata of the single top strain for each of the two donors.
spgc_meta.loc[x.idxmax()[["D0044", "D0097"]]]

In [None]:
# (donor, strain) pairs whose mean fraction exceeds 0.1. Fractions are averaged
# first across each subject's samples, then across the subjects that share a
# donor (grouped by `subject.donor_subject_id`).
donor_strain_pairs = idxwhere(
    sfacts_fit.community.to_pandas()
    .groupby(mgen_meta.subject_id)
    .mean()
    .groupby(subject.donor_subject_id)
    .mean()
    .stack()
    > 0.1
)
# Strain is the second element of each pair.
donor_strain_list = [pair[1] for pair in donor_strain_pairs]
print(donor_strain_pairs)
# `reindex` (not `.loc`) so strains without SPGC metadata show up as NaN rows.
spgc_meta.reindex(donor_strain_list)

In [None]:
# Most abundant strains in subjects with each donor:

# Strain-by-donor table of mean community fraction over all samples whose
# subject has that donor.
x = sfacts_fit.community.to_pandas().groupby(mgen_meta.donor_subject_id).mean().T
print(x["D0044"].sort_values(ascending=False).head(3))
print(x["D0097"].sort_values(ascending=False).head(3))
# The top strain per donor; these two names are reused by later cells.
d44_strain, d97_strain = x.idxmax()[["D0044", "D0097"]]
spgc_meta.loc[x.idxmax()[["D0044", "D0097"]]]

In [None]:
# Most frequently "pure" strains in subjects with each donor:

# Strain-by-donor table counting the samples assigned to each SPGC strain.
x = (
    sample_to_spgc.groupby(mgen_meta.donor_subject_id)
    .value_counts()
    .unstack(fill_value=0)
    .T
)

print(x["D0044"].sort_values(ascending=False).head(3))
print(x["D0097"].sort_values(ascending=False).head(3))
spgc_meta.loc[x.idxmax()[["D0044", "D0097"]]]

In [None]:
# Genotype-distance clustermap with the two donor strains highlighted.
x = strain_geno_pdist

# Annotation colours: genome type, and donor strain (D0044 strain black,
# D0097 strain white, everything else grey).
_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.index.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_geno_linkage,
    col_linkage=strain_geno_linkage,
    xticklabels=0,
    yticklabels=0,
    # figsize=(40, 40),
)

In [None]:
# All pairwise genotype distances, with the distance between the two donor
# strains printed and marked by a dashed line.
# `.loc` raises KeyError if either donor strain is absent from `strain_geno_pdist`.
plt.hist(squareform(strain_geno_pdist))
print(strain_geno_pdist.loc[d44_strain, d97_strain])
plt.axvline(
    strain_geno_pdist.loc[d44_strain, d97_strain], lw=1, linestyle="--", color="k"
)

In [None]:
# Same annotated community plot as above, but over every sample in the fit.
# `.loc[sfacts_fit.sample]` requires each fitted sample to appear in `mgen_meta`.
_col_colors = pd.DataFrame(
    dict(
        is_donor=(mgen_meta.sample_type == "donor").map(
            {True: "black", False: "darkgrey"}
        ),
        subject=mgen_meta.loc[sfacts_fit.sample].subject_id.map(subject_palette),
        donor=mgen_meta.loc[sfacts_fit.sample].donor_subject_id.map(subject_palette),
        spgc=sample_to_spgc.map(spgc_palette),
    )
)

_col_linkage = sfacts_fit.unifrac_linkage()

sf.plot.plot_community(
    sfacts_fit,
    col_colors=_col_colors,
    col_linkage=_col_linkage,
    row_linkage_func=lambda w: w.genotype.linkage(),
)

In [None]:
# The four subjects with the most metagenomes, with their subject metadata.
mgen_meta.subject_id.value_counts().head(4).to_frame(name="num_mgen").join(subject)

In [None]:
# Metagenome counts per subject among recipients of donor D0097.
mgen_meta[lambda x: x.donor_subject_id.isin(["D0097"])].subject_id.value_counts()

In [None]:
# Recipient subjects to plot, keyed by donor.
# Candidates currently left out: S0056, S0061, S0060, S0059 (D0044) and
# S0004, S0027 (D0097).
subject_lists = dict(
    D0044=["S0041", "S0053", "S0055"],
    D0097=["S0001", "S0024", "S0021"],
)
all_subjects_list = [*subject_lists["D0044"], *subject_lists["D0097"]]

# Metagenomes excluded from the per-subject trajectories.
drop_mgen_list = ["SS01117", "SS01120", "SS01126", "SS01185", "SS01008"]

# Time points in plotting order; each one's x position is its rank in this list.
sample_type_specific_order = [
    "baseline",
    "pre_maintenance_1",
    "pre_maintenance_2",
    "pre_maintenance_3",
    "pre_maintenance_4",
    "pre_maintenance_5",
    "pre_maintenance_6",
    "followup_1",
    "followup_2",
]
sample_type_specific_x = {
    name: position for position, name in enumerate(sample_type_specific_order)
}

# The 20 strains with the highest mean fraction across the kept metagenomes of
# the plotted subjects and the two donors. The order is then shuffled with a
# fixed seed, so the result is reproducible.
top_strains_list = (
    sfacts_fit.community.to_pandas()
    .reindex(
        mgen_meta.drop(drop_mgen_list)[
            lambda x: x.subject_id.isin(all_subjects_list + ["D0044", "D0097"])
        ].index.to_list()
    )
    .mean()
    .sort_values(ascending=False)
    .head(20)
    .sample(frac=1.0, random_state=3)
    .index
)
# Strain colours; `extend` pins the two donor strains to fixed colours
# (behaviour for keys outside the list is defined in lib.plot — TODO confirm).
_palette = lib.plot.construct_ordered_palette(
    top_strains_list, cm="Spectral", extend={d44_strain: "blue", d97_strain: "green"}
)

# One row of panels per donor, one column per recipient subject. Columns are
# sized by the longest subject list so zip() never silently drops a subject.
_donor_rows = ["D0044", "D0097"]
ncols = max(len(subject_lists[_donor]) for _donor in _donor_rows)
nrows = 2
# squeeze=False keeps `axs` two-dimensional even when ncols == 1.
fig, axs = plt.subplots(
    nrows,
    ncols,
    figsize=(4 * ncols, 3 * nrows),
    sharex=True,
    sharey=True,
    squeeze=False,
)

# Loop-invariant: sample-by-strain community fractions.
_community = sfacts_fit.community.to_pandas()

for donor_id, ax_row in zip(_donor_rows, axs):
    for subject_id, ax in zip(subject_lists[donor_id], ax_row):
        ax.set_title(subject_id)
        # Select data
        # subject_id = 'S0004'  # S0041 has D0044 and 12 samples
        _subject_baseline_sample_list = mgen_meta[
            lambda x: (x.subject_id == subject_id)
            & (x.sample_type_specific == "baseline")
            & (~x.index.isin(drop_mgen_list))
        ].index
        # The check fails for zero *and* for multiple baseline samples, so the
        # message reports the actual count.
        assert len(_subject_baseline_sample_list) == 1, (
            f"Expected exactly one baseline sample for {subject_id}, "
            f"found {len(_subject_baseline_sample_list)}"
        )
        # Strains above 10% in the baseline sample (none if that sample is
        # missing from the fit: reindex fills it with zeros).
        subject_baseline_strain_list = idxwhere(
            sfacts_fit.community.data.to_pandas()
            .reindex([_subject_baseline_sample_list[0]], fill_value=0)
            .squeeze()
            > 1e-1
        )

        # Reshape data
        _strain_list = set(donor_strain_list) | set(subject_baseline_strain_list)
        _meta = mgen_meta[
            lambda x: (x.subject_id == subject_id) & (~x.index.isin(drop_mgen_list))
        ]
        _strain_frac = _community.reindex(_meta.index).dropna()
        _meta, _strain_frac = lib.pandas_util.align_indexes(_meta, _strain_frac)
        # Column -1 pools every strain that is not tracked individually.
        _strain_frac[-1] = 1 - _strain_frac[list(_strain_list)].sum(1)
        _strain_frac = _strain_frac[list(_strain_list | {-1})]
        _strain_list |= {-1}

        # Re-key rows by time point and keep only the time points present.
        _strain_frac = (
            _strain_frac.rename(_meta.sample_type_specific)
            .reindex(sample_type_specific_order)
            .dropna()
        )
        _meta = (
            _meta.rename(_meta.sample_type_specific)
            .reindex(sample_type_specific_order)
            .dropna(subset=["sample_id"])
        )

        _x_positions = _meta.index.map(sample_type_specific_x)
        for _strain in _strain_list:
            # Donor strains are drawn with heavier lines and markers.
            _is_donor_strain = _strain in donor_strain_list
            ax.plot(
                _x_positions,
                _strain_frac[_strain],
                color=_palette[_strain],
                lw=4 if _is_donor_strain else 2,
                alpha=0.9,
                marker=".",
                markersize=10 if _is_donor_strain else 5,
            )

        # TODO: Strain legend?
        # label=_strain, color=spgc_palette[_strain]
        # ax.set_yscale('log')
        # Set on every panel explicitly instead of relying on the shared
        # y-axis to propagate it from the last panel of each row.
        ax.set_yscale("symlog", linthresh=1e-2)
        ax.set_ylim(5e-3, 2e0)
        # ax.legend(bbox_to_anchor=(1, 1))
# One tick per time point (x-axis is shared across panels).
axs[0, 0].set_xticks(np.arange(len(sample_type_specific_x)))
fig.tight_layout()

# fig, ax = plt.subplots()
# for _strain_id in top_strains_list:
#     ax.scatter([], [], color=_palette[_strain_id], label=_strain_id)
# ax.legend()

## Strain Geno/Gene Spaces

In [None]:
# NOTE: This will take ~2 minutes to run for 40,000 genes.
# Condensed gene-by-gene cosine distances over strain presence profiles, plus
# the square matrix form. The square matrix holds (number of genes)^2 values.
gene_uhgg_cdmat = sp.spatial.distance.pdist(strain_gene_uhgg, metric="cosine")
gene_uhgg_pdist = pd.DataFrame(
    sp.spatial.distance.squareform(gene_uhgg_cdmat),
    index=strain_gene_uhgg.index,
    columns=strain_gene_uhgg.index,
)

In [None]:
# Average-linkage clustering of genes from the condensed cosine distances.
gene_uhgg_linkage = sp.cluster.hierarchy.linkage(gene_uhgg_cdmat, method="average")
gene_uhgg_linkage.shape

In [None]:
# Strain-by-strain cosine distances over gene content, with each gene weighted
# by `ref_gene_uhgg_entropy` (computed from prevalence across the selected
# reference genomes in the Ref Strains section).
# NOTE(review): `.loc` raises KeyError if `strain_gene_uhgg` contains a gene
# that is not in the reference gene table — confirm the gene sets match.
strain_gene_uhgg_cdmat = sp.spatial.distance.pdist(
    strain_gene_uhgg.T,
    metric="cosine",
    w=ref_gene_uhgg_entropy.loc[strain_gene_uhgg.index],
)
strain_gene_uhgg_pdist = pd.DataFrame(
    sp.spatial.distance.squareform(strain_gene_uhgg_cdmat),
    index=strain_gene_uhgg.columns,
    columns=strain_gene_uhgg.columns,
)

# Later cells compare the two distance matrices; they must list strains in the same order.
assert (strain_geno_pdist.index == strain_gene_uhgg_pdist.index).all()

In [None]:
# Average-linkage clustering of strains by gene content, with optimal leaf ordering.
strain_gene_uhgg_linkage = sp.cluster.hierarchy.linkage(
    strain_gene_uhgg_cdmat, method="average", optimal_ordering=True
)

In [None]:
# Gene presence heatmap: columns (strains) ordered by the *genotype* linkage,
# rows (genes) pre-sorted by the gene linkage.
x = strain_gene_uhgg
_col_linkage = strain_geno_linkage
_row_linkage = gene_uhgg_linkage
# Order x by leaf order.
# `leaves_list` returns the dendrogram's left-to-right leaf order straight from
# the linkage matrix, without first building a Python ClusterNode tree.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.leaves_list.html>
x = x.iloc[sp.cluster.hierarchy.leaves_list(_row_linkage)]

# Column annotations: genome type, and donor strain (D0044 strain black,
# D0097 strain white, everything else grey).
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

# Rows are already ordered, so seaborn's row clustering is switched off.
sns.clustermap(
    x,
    row_cluster=False,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)

In [None]:
# Gene presence heatmap: columns (strains) ordered by the *gene-content*
# linkage, rows (genes) pre-sorted by the gene linkage.
x = strain_gene_uhgg
_col_linkage = strain_gene_uhgg_linkage
_row_linkage = gene_uhgg_linkage
# Order x by leaf order.
# `leaves_list` returns the dendrogram's left-to-right leaf order straight from
# the linkage matrix, without first building a Python ClusterNode tree.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.leaves_list.html>
x = x.iloc[sp.cluster.hierarchy.leaves_list(_row_linkage)]

# Column annotations: genome type, and donor strain (D0044 strain black,
# D0097 strain white, everything else grey).
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

# Rows are already ordered, so seaborn's row clustering is switched off.
sns.clustermap(
    x,
    row_cluster=False,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)

In [None]:
# Long-format per-gene, per-strain depth ratio and correlation tables,
# reshaped to gene-by-strain with string column labels.
spgc_gene_uhgg_depth = (
    pd.read_table(spgc_gene_uhgg_depth_inpath, index_col=["gene_id", "strain"])
    .depth.unstack("strain")
    .rename(columns=str)
)
spgc_gene_uhgg_corr = (
    pd.read_table(spgc_gene_uhgg_corr_inpath, index_col=["gene_id", "strain"])
    .correlation.unstack("strain")
    .rename(columns=str)
)

In [None]:
# Gene-by-strain table mixing 0/1 presence for reference genomes with depth
# ratios for SPGC strains, restricted to the genes in `strain_gene_uhgg`.
# Missing values are filled with 0.
_ref = ref_gene_uhgg[ref_list]
_spgc = spgc_gene_uhgg_depth[spgc_list]
strain_gene_uhgg_depth = (
    pd.concat([_ref, _spgc], axis=1).fillna(0).loc[strain_gene_uhgg.index]
)

In [None]:
# Gene-by-strain table mixing 0/1 presence for reference genomes with depth
# correlations for SPGC strains, restricted to the genes in `strain_gene_uhgg`.
# Missing values are filled with 0.
_ref = ref_gene_uhgg[ref_list]
_spgc = spgc_gene_uhgg_corr[spgc_list]
strain_gene_uhgg_corr = (
    pd.concat([_ref, _spgc], axis=1).fillna(0).loc[strain_gene_uhgg.index]
)

In [None]:
# Depth-ratio heatmap: columns (strains) ordered by the gene-content linkage,
# rows (genes) pre-sorted by the gene linkage.
x = strain_gene_uhgg_depth
_col_linkage = strain_gene_uhgg_linkage
_row_linkage = gene_uhgg_linkage
# Order x by leaf order.
# `leaves_list` returns the dendrogram's left-to-right leaf order straight from
# the linkage matrix, without first building a Python ClusterNode tree.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.leaves_list.html>
x = x.iloc[sp.cluster.hierarchy.leaves_list(_row_linkage)]

# Column annotations: genome type, and donor strain (D0044 strain black,
# D0097 strain white, everything else grey).
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

# Square-root colour scale over the range 0 to 2; rows are already ordered.
sns.clustermap(
    x,
    row_cluster=False,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2),
    xticklabels=0,
    yticklabels=0,
)

In [None]:
# Correlation heatmap: columns (strains) ordered by the gene-content linkage,
# rows (genes) pre-sorted by the gene linkage.
x = strain_gene_uhgg_corr
_col_linkage = strain_gene_uhgg_linkage
_row_linkage = gene_uhgg_linkage
# Order x by leaf order.
# `leaves_list` returns the dendrogram's left-to-right leaf order straight from
# the linkage matrix, without first building a Python ClusterNode tree.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.leaves_list.html>
x = x.iloc[sp.cluster.hierarchy.leaves_list(_row_linkage)]

# Column annotations: genome type, and donor strain (D0044 strain black,
# D0097 strain white, everything else grey).
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

# Rows are already ordered, so seaborn's row clustering is switched off.
sns.clustermap(x, row_cluster=False, col_linkage=_col_linkage, col_colors=_col_colors)

In [None]:
# Genotype distance (A) against gene-content distance (B) for every strain
# pair, split into ref-ref, spgc-spgc and spgc-ref, with a Pearson correlation
# printed for each group.
_pdistA = strain_geno_pdist
_pdistB = strain_gene_uhgg_pdist

# Symmetric blocks are condensed with `squareform` so each pair appears once.
plt.scatter(
    squareform(_pdistA.loc[ref_list, ref_list]),
    squareform(_pdistB.loc[ref_list, ref_list]),
    s=1,
    alpha=0.5,
    label="ref-ref",
)
print(
    "ref-ref",
    sp.stats.pearsonr(
        squareform(_pdistA.loc[ref_list, ref_list]),
        squareform(_pdistB.loc[ref_list, ref_list]),
    ),
)

plt.scatter(
    squareform(_pdistA.loc[spgc_list, spgc_list]),
    squareform(_pdistB.loc[spgc_list, spgc_list]),
    s=1,
    alpha=0.5,
    label="spgc-spgc",
)
# With two or fewer SPGC strains there is at most one pair, too few for a correlation.
if len(spgc_list) > 2:
    print(
        "spgc-spgc",
        sp.stats.pearsonr(
            squareform(_pdistA.loc[spgc_list, spgc_list]),
            squareform(_pdistB.loc[spgc_list, spgc_list]),
        ),
    )

# The cross block is rectangular, so it is flattened rather than condensed.
plt.scatter(
    _pdistA.loc[spgc_list, ref_list].values.flatten(),
    _pdistB.loc[spgc_list, ref_list].values.flatten(),
    s=1,
    alpha=0.5,
    label="spgc-ref",
)
print(
    "spgc-ref",
    sp.stats.pearsonr(
        _pdistA.loc[spgc_list, ref_list].values.flatten(),
        _pdistB.loc[spgc_list, ref_list].values.flatten(),
    ),
)

plt.legend(markerscale=5)

In [None]:
bins = np.linspace(0, 1, num=200)

# Remove the diagonal from "minimum distance".
# (Adding 1 there means a strain's zero self-distance never wins the minimum.)
_pdistA = strain_geno_pdist + np.eye(len(strain_geno_pdist))
_pdistB = strain_gene_uhgg_pdist + np.eye(len(strain_gene_uhgg_pdist))

# (label, row strains, column strains). `.min()` reduces over rows, so each
# column strain contributes its nearest-neighbour distance among the row
# strains: genotype distance on x, gene-content distance on y.
_comparisons = [
    ("ref-to-ref", ref_list, ref_list),
    ("spgc-to-spgc", spgc_list, spgc_list),
    ("ref-to-spgc", spgc_list, ref_list),
    ("spgc-to-ref", ref_list, spgc_list),
]
for _label, _row_strains, _col_strains in _comparisons:
    plt.scatter(
        _pdistA.loc[_row_strains, _col_strains].min(),
        _pdistB.loc[_row_strains, _col_strains].min(),
        s=10,
        alpha=0.5,
        label=_label,
    )
plt.xlabel("minimum_genotype_diss")
plt.ylabel("minimum_gene_diss")

plt.legend()

In [None]:
# Genotype vs. gene-content distance with reference pairs in black and, for
# each SPGC strain, its distances to all references in that strain's colour.
_pdistA = strain_geno_pdist
_pdistB = strain_gene_uhgg_pdist

plt.scatter(
    squareform(_pdistA.loc[ref_list, ref_list]),
    squareform(_pdistB.loc[ref_list, ref_list]),
    s=1,
    alpha=0.5,
    color="k",
    label="ref-ref",
)

for spgc_strain_id in spgc_list:
    plt.scatter(
        _pdistA.loc[spgc_strain_id, ref_list],
        _pdistB.loc[spgc_strain_id, ref_list],
        s=1,
        alpha=0.5,
        color=spgc_palette[spgc_strain_id],
        # label='spgc-ref',
    )
    # Per-strain correlation between the two distances to the references.
    print(
        spgc_strain_id,
        sp.stats.pearsonr(
            _pdistA.loc[spgc_strain_id, ref_list], _pdistB.loc[spgc_strain_id, ref_list]
        ),
    )

# Only the reference series carries a label, so the legend has a single entry.
plt.legend()

### Gene Filtering

#### Enriched/Depleted in SPGC Strains

In [None]:
# Per-gene prevalence among the selected reference genomes and SPGC strains.
# NOTE: this rebinds `ref_gene_uhgg_prevalence`, first computed from
# `ref_gene_uhgg[ref_list]` in the Ref Strains section; here it covers only the
# genes kept in `strain_gene_uhgg`.
ref_gene_uhgg_prevalence = strain_gene_uhgg[ref_list].mean(1)
spgc_gene_uhgg_prevalence = strain_gene_uhgg[spgc_list].mean(1)

In [None]:
# Reference prevalence (x) against SPGC prevalence (y), one point per gene.
x = ref_gene_uhgg_prevalence
y = spgc_gene_uhgg_prevalence

print(sp.stats.pearsonr(x, y))

fig, axs = plt.subplots(2, figsize=(5, 10))

# Top panel: the full 0-1 range, with a cube-root colour scale capped at 1e3.
bins0 = np.linspace(0.0, 1.0, num=50)
axs[0].hist2d(x, y, bins=bins0, norm=mpl.colors.PowerNorm(1 / 3, vmin=0, vmax=1e3))

# Bottom panel: the 0.1-0.9 range only, leaving out the extremes.
bins1 = np.linspace(0.1, 0.9, num=40)
axs[1].hist2d(x, y, bins=bins1, norm=mpl.colors.PowerNorm(1 / 3))
axs[1].set_xlabel("reference prevalence")
axs[1].set_ylabel("inferred prevalence")
# Suppress the repr of the last expression.
None

In [None]:
# Per-gene difference in prevalence (SPGC strains minus reference strains),
# computed once and thresholded at several stringency levels below.
# Positive values mean the gene is more prevalent among SPGC strains.
_spgc_minus_ref_prevalence = spgc_gene_uhgg_prevalence - ref_gene_uhgg_prevalence

spgc_extremely_enriched = idxwhere(_spgc_minus_ref_prevalence > 0.9)
spgc_extremely_depleted = idxwhere(_spgc_minus_ref_prevalence < -0.9)
spgc_very_enriched = idxwhere(_spgc_minus_ref_prevalence > 0.5)
spgc_very_depleted = idxwhere(_spgc_minus_ref_prevalence < -0.5)
spgc_highly_enriched = idxwhere(_spgc_minus_ref_prevalence > 0.35)
spgc_highly_depleted = idxwhere(_spgc_minus_ref_prevalence < -0.35)
spgc_enriched = idxwhere(_spgc_minus_ref_prevalence > 0.25)
spgc_depleted = idxwhere(_spgc_minus_ref_prevalence < -0.25)
# Inclusive bounds: together with `spgc_enriched` (> 0.25) and `spgc_depleted`
# (< -0.25) this partitions every gene, so a difference of exactly +/-0.25
# (easy to hit when there are only a few SPGC strains) is not left unclassified.
spgc_similar = idxwhere(
    (_spgc_minus_ref_prevalence >= -0.25) & (_spgc_minus_ref_prevalence <= 0.25)
)

#### Gene Length / Singletons

In [None]:
# Classes ordered from most SPGC-enriched to most SPGC-depleted.
_enrichment_classes_in_order = [
    spgc_extremely_enriched,
    spgc_very_enriched,
    spgc_highly_enriched,
    spgc_enriched,
    spgc_similar,
    spgc_depleted,
    spgc_highly_depleted,
    spgc_very_depleted,
    spgc_extremely_depleted,
]

# For each class print: number of genes, mean gene length, std of gene length.
for spgc_enrichment_class_list in _enrichment_classes_in_order:
    _class_gene_lengths = uhgg_gene_length.loc[spgc_enrichment_class_list]
    print(
        len(spgc_enrichment_class_list),
        _class_gene_lengths.mean(),
        _class_gene_lengths.std(),
    )

In [None]:
# Genes whose recorded length is below 300 (units are whatever
# `uhgg_gene_length` stores -- presumably bp; TODO confirm).
short_genes = idxwhere(uhgg_gene_length < 300)
# Genes whose row sum across all strain columns is at most 1.
singleton_genes = idxwhere(strain_gene_uhgg.sum(1) <= 1)

#### Final Filter

In [None]:
# Genes to remove: strongly SPGC-enriched or -depleted, short, or singleton.
# The `+` concatenation assumes `idxwhere` returns lists (a gene may appear in
# more than one of them) -- NOTE(review): confirm against lib.pandas_util.
drop_list = spgc_highly_enriched + spgc_highly_depleted + short_genes + singleton_genes
print(
    len(spgc_highly_enriched),
    len(spgc_highly_depleted),
    len(short_genes),
    len(singleton_genes),
)

# errors="ignore" skips any label in drop_list that is not in the gene table's index.
strain_uhgg_filt = strain_gene_uhgg.drop(index=drop_list, errors="ignore")
# Gene counts before and after filtering.
print(strain_gene_uhgg.shape[0], strain_uhgg_filt.shape[0])

In [None]:
# Shared bins spanning 0 .. the largest per-strain column sum of the unfiltered table.
bins = np.linspace(0, strain_gene_uhgg[spgc_list + ref_list].sum().max(), num=50)

# (strain columns, color, label for unfiltered table, label for filtered table)
_hist_groups = [
    (ref_list, "tab:blue", "ref (uhgg)", "ref (filtered uhgg)"),
    (spgc_list, "tab:orange", "spgc (uhgg)", "spgc (filtered uhgg)"),
]

# Filled histograms: per-strain column sums before filtering.
for _strains, _color, _raw_label, _ in _hist_groups:
    plt.hist(
        strain_gene_uhgg[_strains].sum(),
        bins=bins,
        histtype="stepfilled",
        label=_raw_label,
        density=True,
        color=_color,
        alpha=0.5,
    )

# Outlined histograms: the same sums after the gene filter.
for _strains, _color, _, _filt_label in _hist_groups:
    plt.hist(
        strain_uhgg_filt[_strains].sum(),
        bins=bins,
        histtype="step",
        label=_filt_label,
        density=True,
        linestyle="-",
        color=_color,
        lw=2,
    )

# (Equivalent histograms for the eggnog gene tables were sketched here
# previously and remain disabled.)

plt.legend()

None

In [None]:
# Pairwise cosine distance between the *rows* (genes) of the filtered table,
# first in condensed form, then as a square DataFrame labelled by gene.
# NOTE(review): the square matrix has one row and column per retained gene, so
# memory grows quadratically with the number of genes.
uhgg_filt_cdmat = sp.spatial.distance.pdist(strain_uhgg_filt, metric="cosine")
uhgg_filt_pdist = pd.DataFrame(
    sp.spatial.distance.squareform(uhgg_filt_cdmat),
    index=strain_uhgg_filt.index,
    columns=strain_uhgg_filt.index,
)

In [None]:
# Average-linkage hierarchical clustering of genes from the condensed distances.
uhgg_filt_linkage = sp.cluster.hierarchy.linkage(uhgg_filt_cdmat, method="average")
# Displayed as a quick sanity check of the linkage matrix size.
uhgg_filt_linkage.shape

In [None]:
# Pairwise cosine distance between *strains* (the table is transposed so that
# strains are the observations), with each gene weighted by its entry in
# `ref_gene_uhgg_entropy`.
strain_uhgg_filt_cdmat = sp.spatial.distance.pdist(
    strain_uhgg_filt.T,
    metric="cosine",
    w=ref_gene_uhgg_entropy.loc[strain_uhgg_filt.index],
)
# Square form, labelled by strain on both axes.
strain_uhgg_filt_pdist = pd.DataFrame(
    sp.spatial.distance.squareform(strain_uhgg_filt_cdmat),
    index=strain_uhgg_filt.columns,
    columns=strain_uhgg_filt.columns,
)

In [None]:
# Average-linkage clustering of strains by filtered gene content;
# optimal_ordering=True asks scipy to reorder leaves so adjacent leaves are close.
strain_uhgg_filt_linkage = sp.cluster.hierarchy.linkage(
    strain_uhgg_filt_cdmat, method="average", optimal_ordering=True
)

In [None]:
_col_linkage = strain_geno_linkage
_row_linkage = uhgg_filt_linkage

# Put the gene rows in the leaf order of the gene dendrogram (pre-order
# traversal of the tree), since row clustering is disabled in the plot below.
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.ClusterNode.pre_order.html#scipy.cluster.hierarchy.ClusterNode.pre_order>
_row_tree = sp.cluster.hierarchy.to_tree(_row_linkage)
_row_leaf_order = _row_tree.pre_order(lambda node: node.id)
x = strain_uhgg_filt.iloc[_row_leaf_order]

# Column annotations: genome type, plus a marker for the two donor strains
# (every other strain is grey).
_donor_marker = {d44_strain: "black", d97_strain: "white"}
_donor_colors = x.columns.to_series().map(_donor_marker).fillna("grey")
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_colors,
    )
)

sns.clustermap(x, row_cluster=False, col_linkage=_col_linkage, col_colors=_col_colors)

In [None]:
# Same heatmap as the genotype-ordered version, but columns follow the
# gene-content linkage of the strains.
_col_linkage = strain_uhgg_filt_linkage
_row_linkage = uhgg_filt_linkage

# Put the gene rows in dendrogram leaf order (row clustering is disabled below).
# See <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.ClusterNode.pre_order.html#scipy.cluster.hierarchy.ClusterNode.pre_order>
_row_tree = sp.cluster.hierarchy.to_tree(_row_linkage)
_row_leaf_order = _row_tree.pre_order(lambda node: node.id)
x = strain_uhgg_filt.iloc[_row_leaf_order]

# Column annotations: genome type and donor-strain marker (others grey).
_donor_marker = {d44_strain: "black", d97_strain: "white"}
_donor_colors = x.columns.to_series().map(_donor_marker).fillna("grey")
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_colors,
    )
)

sns.clustermap(x, row_cluster=False, col_linkage=_col_linkage, col_colors=_col_colors)

### Gene Content Divergence from Refs

Are SPGC Genomes Believably Distinct?

Test whether SPGC strains are surprisingly distinct from the reference genomes in gene content, relative to how distinct they are in genotype.

In [None]:
_pdistA = strain_geno_pdist
_pdistB = strain_uhgg_filt_pdist

# Reference vs. reference pairs (condensed form of the square sub-matrix).
_ref_ref_geno = squareform(_pdistA.loc[ref_list, ref_list])
_ref_ref_gene = squareform(_pdistB.loc[ref_list, ref_list])
plt.scatter(_ref_ref_geno, _ref_ref_gene, s=1, alpha=0.5, label="ref-ref")
print("ref-ref", sp.stats.pearsonr(_ref_ref_geno, _ref_ref_gene))

# SPGC vs. SPGC pairs; the correlation is only reported with more than two
# SPGC strains.
_spgc_spgc_geno = squareform(_pdistA.loc[spgc_list, spgc_list])
_spgc_spgc_gene = squareform(_pdistB.loc[spgc_list, spgc_list])
plt.scatter(_spgc_spgc_geno, _spgc_spgc_gene, s=1, alpha=0.5, label="spgc-spgc")
if len(spgc_list) > 2:
    print("spgc-spgc", sp.stats.pearsonr(_spgc_spgc_geno, _spgc_spgc_gene))

# Every SPGC-reference pair (rectangular block, flattened).
_spgc_ref_geno = _pdistA.loc[spgc_list, ref_list].values.flatten()
_spgc_ref_gene = _pdistB.loc[spgc_list, ref_list].values.flatten()
plt.scatter(_spgc_ref_geno, _spgc_ref_gene, s=1, alpha=0.5, label="spgc-ref")
print("spgc-ref", sp.stats.pearsonr(_spgc_ref_geno, _spgc_ref_gene))

plt.legend(markerscale=5)

In [None]:
# Adding the identity matrix puts a 1 on each diagonal entry, so a strain's
# self-distance is not picked up as its "minimum distance".
_pdistA = strain_geno_pdist + np.eye(len(strain_geno_pdist))
_pdistB = strain_uhgg_filt_pdist + np.eye(len(strain_uhgg_filt_pdist))

# (row strains, column strains, legend label). `.min()` reduces over the rows,
# so each comparison yields one point per *column* strain.
_min_dist_comparisons = [
    (ref_list, ref_list, "ref-to-ref"),
    (ref_list, spgc_list, "spgc-to-ref"),
    (spgc_list, spgc_list, "spgc-to-spgc"),
    (spgc_list, ref_list, "ref-to-spgc"),
]
for _row_strains, _col_strains, _label in _min_dist_comparisons:
    plt.scatter(
        _pdistA.loc[_row_strains, _col_strains].min(),
        _pdistB.loc[_row_strains, _col_strains].min(),
        s=10,
        alpha=0.5,
        label=_label,
    )

plt.xlabel("minimum_genotype_diss")
plt.ylabel("minimum_gene_diss")

plt.legend()

In [None]:
# Regress each strain's minimum gene-content distance to a reference on its
# minimum genotype distance to a reference, with an interaction term for
# whether the strain is an SPGC strain.
# NOTE(review): this import is local to the cell; consider moving it to the
# notebook's import cell.
import statsmodels.formula.api as smf

_pdistA = strain_geno_pdist
_pdistB = strain_uhgg_filt_pdist

# Pseudocount added inside the log2 below: one over the number of genotyped positions.
geno_pdist_adjust = 1 / sfacts_fit.sizes["position"]
# gene_pdist_adjust = (1 / strain_gene_filt_eggnog.shape[0])


# Remove the diagonal from "minimum distance".
_pdistA = _pdistA + np.eye(len(_pdistA))
_pdistB = _pdistB + np.eye(len(_pdistB))

# `.min()` reduces over the reference rows: one value per column strain.
ref_to_ref_min_pdistA = _pdistA.loc[ref_list, ref_list].min()
ref_to_ref_min_pdistB = _pdistB.loc[ref_list, ref_list].min()
spgc_to_ref_min_pdistA = _pdistA.loc[ref_list, spgc_list].min()
spgc_to_ref_min_pdistB = _pdistB.loc[ref_list, spgc_list].min()

# One row per strain; `spgc` is False for the reference rows and True for SPGC rows.
d0 = pd.DataFrame(
    dict(
        geno_dist=pd.concat([ref_to_ref_min_pdistA, spgc_to_ref_min_pdistA]),
        gene_dist=pd.concat([ref_to_ref_min_pdistB, spgc_to_ref_min_pdistB]),
        spgc=np.concatenate(
            [np.zeros_like(ref_to_ref_min_pdistA), np.ones_like(spgc_to_ref_min_pdistA)]
        ).astype(bool),
    )
)  # .assign(
#     geno_dist_pc=lambda x: x.geno_dist + geno_pdist_adjust,
#     gene_dist_pc=lambda x: x.gene_dist + gene_pdist_adjust,
# )
# The pseudocount is interpolated into the formula string as a literal number.
fit = smf.ols(
    f"gene_dist ~ np.log2(geno_dist + {geno_pdist_adjust}) * spgc", data=d0
).fit()
# Attach fitted values and Pearson residuals; sort by geno_dist so the
# prediction lines below are drawn left to right.
d1 = d0.assign(
    gene_dist_predict=lambda x: fit.predict(), gene_dist_resid_pearson=fit.resid_pearson
).sort_values("geno_dist")

fig, axs = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: observed points and fitted line per group, linear x-axis.
ax = axs[0]
for _spgc, d2 in d1.groupby("spgc"):
    ax.scatter("geno_dist", "gene_dist", data=d2, label=_spgc)
    ax.plot("geno_dist", "gene_dist_predict", data=d2, label="__nolegend__")

# Right panel: the same plot with a symlog x-axis (and the legend).
ax = axs[1]
for _spgc, d2 in d1.groupby("spgc"):
    ax.scatter("geno_dist", "gene_dist", data=d2, label=_spgc)
    ax.plot("geno_dist", "gene_dist_predict", data=d2, label="__nolegend__")
ax.legend(title="SPGC")
ax.set_xscale("symlog", linthresh=1e-4)
# Last expression: the regression summary table is the cell's output.
fit.summary()

In [None]:
# Residual diagnostic: Pearson residuals against fitted values, one series per
# value of the `spgc` flag.
for _spgc, d2 in d1.groupby("spgc"):
    plt.scatter("gene_dist_predict", "gene_dist_resid_pearson", data=d2, label=_spgc)
plt.legend()

In [None]:
_pdistA = strain_geno_pdist
_pdistB = strain_uhgg_filt_pdist

# Background: every reference-vs-reference pair (condensed form), in black.
_ref_ref_geno = squareform(_pdistA.loc[ref_list, ref_list])
_ref_ref_gene = squareform(_pdistB.loc[ref_list, ref_list])
plt.scatter(_ref_ref_geno, _ref_ref_gene, s=1, alpha=0.5, color="k", label="ref-ref")

# Overlay: one colored cloud per SPGC strain (its distances to each reference),
# and print that strain's genotype-vs-gene-content correlation.
for spgc_strain_id in spgc_list:
    _geno_to_refs = _pdistA.loc[spgc_strain_id, ref_list]
    _gene_to_refs = _pdistB.loc[spgc_strain_id, ref_list]
    plt.scatter(
        _geno_to_refs,
        _gene_to_refs,
        s=1,
        alpha=0.5,
        color=spgc_palette[spgc_strain_id],
    )
    print(spgc_strain_id, sp.stats.pearsonr(_geno_to_refs, _gene_to_refs))

plt.legend()

## Strain Diversity Analysis

In [None]:
# Genotype distance matrix with exact zeros (including the diagonal) masked as NaN.
x = strain_geno_pdist.replace({0: np.nan})

# Row/column annotations: genome type and donor-strain marker (others grey).
_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

# Rows are ordered by the gene-content linkage and columns by the genotype
# linkage, so the two orderings can be compared on a single matrix.
g = sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_uhgg_filt_linkage,
    col_linkage=strain_geno_linkage,
    # figsize=(40, 40),
)
# Background color for the heatmap axes, so the NaN-masked cells stand out.
g.ax_heatmap.set_facecolor("aqua")

In [None]:
# Filtered gene-content distance matrix with exact zeros masked as NaN.
x = strain_uhgg_filt_pdist.replace({0: np.nan})

# Annotations: genome type (g) and donor strain (d), using each donor strain's
# own palette color; every other strain is grey.
_colors = pd.DataFrame(
    dict(
        g=genome_type.map(genome_type_palette),
        d=x.columns.to_series()
        .map(
            {d44_strain: spgc_palette[d44_strain], d97_strain: spgc_palette[d97_strain]}
        )
        .fillna("grey"),
    )
)

# Rows follow the gene-content linkage, columns the genotype linkage; tick
# labels are suppressed and the color scale is fixed to [0, 1].
g = sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_uhgg_filt_linkage,
    col_linkage=strain_geno_linkage,
    # figsize=(40, 40),
    xticklabels=0,
    yticklabels=0,
    vmin=0,
    vmax=1,
)
# Background color for the heatmap axes, so the NaN-masked cells stand out.
g.ax_heatmap.set_facecolor("aqua")
# g.cax.set_visible(False)

In [None]:
# Filtered gene-content distance matrix with exact zeros masked as NaN.
x = strain_uhgg_filt_pdist.replace({0: np.nan})

# Only the genome-type annotation is shown in this version of the figure
# (the donor-strain color column is intentionally omitted).
_colors = pd.DataFrame(
    dict(
        g=genome_type.map(genome_type_palette),
    )
)

# Rows follow the gene-content linkage, columns the genotype linkage; tick
# labels are suppressed, the color scale is fixed to [0, 1], and dendrogram
# lines are thickened.
g = sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_uhgg_filt_linkage,
    col_linkage=strain_geno_linkage,
    # figsize=(40, 40),
    xticklabels=0,
    yticklabels=0,
    vmin=0,
    vmax=1,
    tree_kws=dict(lw=1.5),
)

# Background color for the heatmap axes, so the NaN-masked cells stand out.
g.ax_heatmap.set_facecolor("aqua")
# Hide the colorbar.
g.cax.set_visible(False)
# Restrict the dendrogram axes to heights in [0.1, 1]; the row dendrogram's
# x-axis keeps its reversed orientation (1 -> 0.1).
g.ax_col_dendrogram.set_ylim(0.1, 1)
g.ax_row_dendrogram.set_xlim(1, 0.1)

### Prevalence Comparisons

In [None]:
# Per-gene prevalence in the filtered table: row-wise mean over the reference
# columns, over the SPGC columns, and over all columns.
ref_uhgg_prevalence = strain_uhgg_filt[ref_list].mean(1)
spgc_uhgg_prevalence = strain_uhgg_filt[spgc_list].mean(1)
gene_prevalence = strain_uhgg_filt.mean(1)

In [None]:
# Same comparison as in the gene-filtering section, now on the filtered table:
# per-gene prevalence in references (x) against SPGC strains (y).
x = ref_uhgg_prevalence
y = spgc_uhgg_prevalence

print(sp.stats.pearsonr(x, y))

fig, axs = plt.subplots(2, figsize=(5, 10))

# Top panel: full [0, 1] range; power-law color scaling with the color range
# capped at 1e3 counts.
bins0 = np.linspace(0.0, 1.0, num=50)
axs[0].hist2d(x, y, bins=bins0, norm=mpl.colors.PowerNorm(1 / 3, vmin=0, vmax=1e3))

# Bottom panel: same data restricted to the [0.1, 0.9] range on both axes.
bins1 = np.linspace(0.1, 0.9, num=40)
axs[1].hist2d(x, y, bins=bins1, norm=mpl.colors.PowerNorm(1 / 3))
axs[1].set_xlabel("reference prevalence")
axs[1].set_ylabel("inferred prevalence")
# Trailing None suppresses the cell's rich output.
None

### Core / Shell / Cloud Pangenome

In [None]:
bins = np.linspace(0, 1, num=51)


def _assign_prevalence_class(p):
    if p > 0.95:
        return "core"
    elif p > 0.1:
        return "shell"
    elif p < 0.1:
        return "cloud"


# Display order and colors for the three prevalence classes, reused by later cells.
prevalence_class_order = ["core", "shell", "cloud"]
prevalence_class_palette = {
    "core": "tab:blue",
    "shell": "tab:orange",
    "cloud": "tab:green",
}

# Classify each gene by its prevalence among the SPGC strains.
spgc_uhgg_class = spgc_uhgg_prevalence.map(_assign_prevalence_class)

# One histogram of SPGC prevalence per class, on shared bins and a log count axis.
for prevalence_class in prevalence_class_order:
    plt.hist(
        spgc_uhgg_prevalence[spgc_uhgg_class == prevalence_class],
        bins=bins,
        label=prevalence_class,
        color=prevalence_class_palette[prevalence_class],
    )
plt.legend()
plt.yscale("log")
# core_genes = idxwhere(strain_uhgg_filt_prevalence > 0.9)
# shell_genes = idxwhere((strain_uhgg_filt_prevalence < 0.9) & (strain_uhgg_filt_prevalence > 0.1))
# cloud_genes = idxwhere(strain_uhgg_filt_prevalence < 0.1)

In [None]:
# NOTE(review): this import is local to the cell; consider moving it to the
# notebook's import cell.
import matplotlib.gridspec as gridspec

# Per-strain gene totals within each prevalence class: sum the gene rows by
# class, then transpose so strains are rows and classes are columns.
d = strain_gene_uhgg.groupby(spgc_uhgg_class).sum().T[prevalence_class_order]

# Pin the bar colors to the shared palette (one per column of `d`, in column
# order) instead of relying on the default color cycle happening to match.
_class_colors = [prevalence_class_palette[c] for c in prevalence_class_order]

fig = plt.figure(figsize=(10, 5), tight_layout=True)

# Panel widths are proportional to the number of strains in each group.
gs = gridspec.GridSpec(
    1, 2, width_ratios=[len(ref_list) / len(spgc_list), 1], figure=fig
)
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)

for ax, strain_list, title in zip(
    [ax0, ax1], [ref_list, spgc_list], ["references", "spgc"]
):
    # Stacked bars, one per strain, sorted by the strain's shell-gene total.
    d.loc[strain_list].sort_values("shell").plot.bar(
        stacked=True, ax=ax, width=1, color=_class_colors
    )
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_aspect(0.08)
    ax.legend_.set_visible(False)
# Show a single legend, on the reference panel.
ax0.legend_.set_visible(True)

In [None]:
# Quantiles of `centroid_99_length` within each prevalence class
# (one row per class, one column per quantile after unstacking).
gene_annotations.groupby(spgc_uhgg_class).centroid_99_length.quantile(
    [0.1, 0.25, 0.5, 0.75, 0.9]
).unstack()

In [None]:
# COG categories in the column order of the gene-by-category matrix, and a
# color per category built by the project helper (with a grey entry for
# "no_category").
cog_category_order = gene_x_cog_category_matrix.columns
cog_category_palette = lib.plot.construct_ordered_palette(
    cog_category_order, cm="rainbow", extend=dict(no_category="grey")
)
# Draw a legend-only figure: empty scatters exist just to create legend handles.
for cog_category in cog_category_order:
    plt.scatter([], [], c=cog_category_palette[cog_category], label=cog_category)
plt.legend(ncols=4)

lib.plot.hide_axes_and_spines()

In [None]:
# Load the COG category descriptions and prefix each one with its category
# code ("<code>: <description>").
cog_category_description = pd.read_table(
    "ref/cog-20.categories.tsv",
    names=["cog_category", "description"],
    index_col="cog_category",
).assign(description=lambda x: x.index + ": " + x.description)
# Add a label for genes without any COG annotation.
cog_category_description.loc["no_category", "description"] = "-: No Annotation"
cog_category_description

In [None]:
# Fisher's exact test for association between each prevalence class and each
# COG category, across genes.
x = spgc_uhgg_class
y = gene_x_cog_category_matrix
# NOTE(review): `gene_list` is not used in this cell.
gene_list = spgc_uhgg_class.index

cog_category_gene_class_enrichment_test = []

for _prevalence_class, _cog_category in product(
    prevalence_class_order, cog_category_order
):
    # 2x2 table: rows = gene is in this prevalence class, columns = gene is in
    # this COG category. The two Series are aligned on their index when the
    # frame is built; reindex/fillna guarantees all four cells exist.
    contingency_table = (
        pd.DataFrame(
            dict(
                is_prev_class=(x == _prevalence_class),
                is_cog_category=y[_cog_category],
            )
        )
        .value_counts()
        .unstack()
        .reindex(index=[False, True], columns=[False, True])
        .fillna(0)
    )
    _test = sp.stats.fisher_exact(contingency_table)
    # Record: class, category, odds ratio, p-value, and the number of genes
    # that are in both the class and the category.
    cog_category_gene_class_enrichment_test.append(
        (
            _prevalence_class,
            _cog_category,
            _test[0],
            _test[1],
            contingency_table.loc[True, True],
        )
    )

# Collect the records into a frame indexed by (class, category) and add
# log-transformed columns for plotting.
cog_category_gene_class_enrichment_test = (
    pd.DataFrame(
        cog_category_gene_class_enrichment_test,
        columns=[
            "prevalence_class",
            "cog_category",
            "statistic",
            "pvalue",
            "gene_count",
        ],
    )
    .set_index(["prevalence_class", "cog_category"])
    .assign(
        negative_log10_pvalue=lambda x: -np.log10(x.pvalue),
        log2_odds_ratio=lambda x: np.log2(x.statistic),
    )
)

In [None]:
def _assign_significance_marker(pvalue):
    if pvalue < 1e-3:
        return "***"
    elif pvalue < 1e-2:
        return "**"
    elif pvalue < 0.05:
        return "*"
    else:
        return ""

In [None]:
# Heatmap values: log2 odds ratio per (COG category, prevalence class).
# Infinite values are replaced by NaN and then, with any other missing values,
# by 0; rows are relabelled with the human-readable category description.
x = (
    cog_category_gene_class_enrichment_test.log2_odds_ratio.unstack("prevalence_class")
    .replace({np.inf: np.nan, -np.inf: np.nan})
    .join(cog_category_description)
    .set_index("description")[prevalence_class_order]
    .fillna(0)
)
# annot = (cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker) + '|' + cog_category_gene_class_enrichment_test.gene_count.astype(int).astype(str)).unstack('prevalence_class')[prevalence_class_order]
# Cell annotations: significance stars derived from the (uncorrected) p-values.
annot = (
    cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker)
    .unstack("prevalence_class")
    .join(cog_category_description)
    .set_index("description")[prevalence_class_order]
)
# annot = cog_category_gene_class_enrichment_test.gene_count.unstack('prevalence_class')[prevalence_class_order].astype(int)

# Rows sorted from most core-enriched to most core-depleted.
_row_order = x["core"].sort_values(ascending=False).index
# x, annot = lib.pandas_util.align_indexes(x, annot)

fig, ax = plt.subplots(figsize=(5, 12))
# Diverging color map centered at 0 and clipped to [-5, 5]; fmt="" lets the
# string annotations be drawn as-is.
ax = sns.heatmap(
    x.reindex(_row_order),
    annot=annot.reindex(_row_order),
    fmt="",
    cmap="coolwarm",
    center=0,
    vmin=-5,
    vmax=5,
    cbar_kws=dict(
        use_gridspec=True, location="left", label="log2(odds ratio)", extend="both"
    ),
    ax=ax,
    yticklabels=1,
    xticklabels=1,
    annot_kws=dict(va="center"),
    # norm=mpl.colors.SymLogNorm(linthresh=1e1),
    # center=0,
)

# Category labels go on the right because the colorbar occupies the left.
ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=-0, va="center")

In [None]:
# Volcano-style plot: log2 odds ratio vs. -log10 p-value, one point per
# (prevalence class, COG category), colored by prevalence class.
fig, ax = plt.subplots()

for prevalence_class in prevalence_class_order:
    d = cog_category_gene_class_enrichment_test.xs(
        prevalence_class, level="prevalence_class"
    )
    ax.scatter(
        "log2_odds_ratio",
        "negative_log10_pvalue",
        data=d,
        c=prevalence_class_palette[prevalence_class],
        label=prevalence_class,
        s=5,
    )

# Annotate top-10
# (the ten smallest p-values; idx is a (prevalence_class, cog_category) tuple)
for idx in (
    cog_category_gene_class_enrichment_test.sort_values(
        "negative_log10_pvalue", ascending=False
    )
    .head(10)
    .index
):
    ax.annotate(
        idx[1],
        xy=cog_category_gene_class_enrichment_test.loc[idx][
            ["log2_odds_ratio", "negative_log10_pvalue"]
        ],
        xytext=(2, 1),
        textcoords="offset pixels",
        color=prevalence_class_palette[idx[0]],
        # ha='center',
        # va='center',
        fontweight="bold",
    )
# Reference lines: no enrichment (x = 0) and the 0.05 threshold divided by the
# number of tests performed (Bonferroni-style).
ax.axvline(0, linestyle="--", lw=1, color="k")
ax.axhline(
    -np.log10(0.05 / cog_category_gene_class_enrichment_test.shape[0]),
    linestyle="--",
    lw=1,
    color="k",
)
ax.legend()
# cog_category_gene_class_enrichment_test.sort_values('negative_log10_pvalue', ascending=False).head(20)

### Strength of Phylogenetic Signal in Shell Genes

In [None]:
# Genes classified as "shell" by their prevalence among SPGC strains.
shell_gene_list = idxwhere(spgc_uhgg_class == "shell")

# Rows for those genes, taken from the unfiltered gene table.
x = strain_gene_uhgg.loc[shell_gene_list]

# Entropy-weighted cosine distance between strains using shell genes only
# (condensed form), its square DataFrame form, and an average-linkage tree.
strain_shell_gene_cdist = sp.spatial.distance.pdist(
    x.T, metric="cosine", w=ref_gene_uhgg_entropy.loc[x.index]
)
strain_shell_gene_pdist = pd.DataFrame(
    squareform(strain_shell_gene_cdist), index=x.columns, columns=x.columns
)
strain_shell_gene_linkage = sp.cluster.hierarchy.linkage(
    strain_shell_gene_cdist, method="average", optimal_ordering=True
)

# _col_linkage = strain_shell_gene_linkage
# Column annotations: genome type and donor-strain marker (others grey).
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)
_col_linkage = strain_shell_gene_linkage

# Columns follow the shell-gene linkage; rows are clustered by seaborn's default.
sns.clustermap(x, col_linkage=_col_linkage, col_colors=_col_colors)

In [None]:
# FIXME: x is ambiguous here.
# Shell-gene distance matrix between strains, with exact zeros masked as NaN.
x = strain_shell_gene_pdist.replace({0: np.nan})
# Columns follow the genotype linkage and rows the shell-gene linkage, so the
# two orderings can be compared on a single matrix.
_col_linkage = strain_geno_linkage
_row_linkage = strain_shell_gene_linkage
# Row/column annotations: genome type and donor-strain marker (others grey).
_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

g = sns.clustermap(
    x,
    col_linkage=_col_linkage,
    row_linkage=_row_linkage,
    col_colors=_colors,
    row_colors=_colors,
)
# Background color for the heatmap axes, so the NaN-masked cells stand out.
g.ax_heatmap.set_facecolor("aqua")

In [None]:
# Using all genomes, not just SPGC
#
# For each gene of intermediate prevalence, correlate the pairwise genotype
# distance between strains with the pairwise difference in that gene's value
# between the same strains.
# NOTE(review): this relies on the columns of `strain_gene_uhgg` being in the
# same strain order as `strain_geno_pdist` -- confirm upstream.

# The condensed genotype distances are identical for every gene, so compute
# them once rather than on every iteration.
x = squareform(strain_geno_pdist)

phylogenetic_signal = {}
for gene_id in tqdm(idxwhere((gene_prevalence > 0.02) & (gene_prevalence < 0.98))):
    # Condensed pairwise (Euclidean, pdist's default) distance of this single
    # gene's row across strains.
    gene_pdist = pdist(strain_gene_uhgg.loc[[gene_id]].T)
    y = gene_pdist
    # Only correlate when there is more than one strain pair.
    if len(x) > 1:
        phylogenetic_signal[gene_id] = sp.stats.pearsonr(x, y)

# One row per gene with the correlation coefficient and its p-value.
phylogenetic_signal = pd.DataFrame(
    phylogenetic_signal, index=["r_statistic", "pvalue"]
).T

In [None]:
# NOTE(review): the sorted frame on the next line is neither assigned nor
# displayed (it is not the cell's last expression), so it has no effect.
phylogenetic_signal.sort_values("r_statistic")
# Rank plot: correlation coefficients in ascending order.
plt.plot(phylogenetic_signal.r_statistic.sort_values().values)

In [None]:
# Shell-gene rows of the unfiltered gene table, all strains.
x = strain_gene_uhgg.loc[shell_gene_list]  # , spgc_list]
# Column annotations: genome type and donor-strain marker (others grey).
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)
# Row annotations: the gene's phylogenetic-signal r, rescaled from [-1, 1] to
# [0, 1] for the diverging color map, and the gene's reference entropy.
_row_colors = pd.DataFrame(
    dict(
        phylo=phylogenetic_signal.r_statistic.map(
            lambda x: mpl.cm.coolwarm((x + 1) / 2)
        ),
        entrp=ref_gene_uhgg_entropy.map(mpl.cm.viridis),
    )
)

# Columns follow the genotype linkage; rows are clustered by seaborn's default.
sns.clustermap(
    x, row_colors=_row_colors, col_linkage=strain_geno_linkage, col_colors=_col_colors
)

In [None]:
# Label shell genes as clearly phylogenetic (True: r > 0.1 and p < 0.001) or
# clearly not (False: r < 0.1 and p > 0.05). Genes meeting neither rule are
# left out, and the result is restricted to the shell genes.
is_phylogenetic = pd.concat(
    [
        pd.Series(
            True,
            index=phylogenetic_signal[
                lambda x: (x.r_statistic > 0.1) & (x.pvalue < 0.001)
            ].index,
        ),
        pd.Series(
            False,
            index=phylogenetic_signal[
                lambda x: (x.r_statistic < 0.1) & (x.pvalue > 0.05)
            ].index,
        ),
    ]
)[lambda x: x.index.isin(shell_gene_list)]

cog_category_phylogenetic_enrichment_test = []

for _cog_category in tqdm(list(cog_category_order)):
    # 2x2 table: rows = is_phylogenetic, columns = is_cog_category. dropna()
    # removes genes that lack a label in either input after index alignment.
    contingency_table = (
        pd.DataFrame(
            dict(
                is_phylogenetic=is_phylogenetic,
                is_cog_category=gene_x_cog_category_matrix[_cog_category],
            )
        )
        .dropna()
        .value_counts()
        .unstack()
        .reindex(index=[False, True], columns=[False, True])
        .fillna(0)
    )
    _test = sp.stats.fisher_exact(contingency_table)
    # Record: category, odds ratio, p-value, then the four cell counts in
    # stacked order (F,F), (F,T), (T,F), (T,T).
    cog_category_phylogenetic_enrichment_test.append(
        (_cog_category, *_test, *contingency_table.stack())
    )

# NOTE(review): the four count columns mix "_" and "-" separators
# ("np_nc", "np_c", "p-nc", "p-c"); presumably p/np = phylogenetic or not and
# c/nc = in category or not.
cog_category_phylogenetic_enrichment_test = (
    pd.DataFrame(
        cog_category_phylogenetic_enrichment_test,
        columns=["cog_category", "statistic", "pvalue", "np_nc", "np_c", "p-nc", "p-c"],
    )
    .set_index(["cog_category"])
    .assign(
        negative_log10_pvalue=lambda x: -np.log10(x.pvalue),
        log2_odds_ratio=lambda x: np.log2(x.statistic),
    )
)

In [None]:
# Display the enrichment results, largest log2 odds ratio first, with the
# category descriptions joined on.
cog_category_phylogenetic_enrichment_test.sort_values(
    "log2_odds_ratio", ascending=False
).join(cog_category_description)

In [None]:
# For each COG category, compare the phylogenetic-signal r of shell genes in
# the category against shell genes outside it (Mann-Whitney U test).
cog_category_phylogenetic_signal_test = {}

# Align the gene-by-category matrix to the genes that have a phylogenetic
# signal, treating genes missing from the matrix as "not in category". This
# does not depend on the category, so it is done once, outside the loop.
_cog_matrix_for_signal_genes = gene_x_cog_category_matrix.reindex(
    phylogenetic_signal.index
).fillna(False)

for _cog_category in tqdm(list(cog_category_order)):
    # One row per shell gene: its r statistic and its category membership.
    d = pd.DataFrame(
        dict(
            phylogenetic_signal=phylogenetic_signal.r_statistic,
            is_cog_category=_cog_matrix_for_signal_genes[_cog_category],
        )
    )[lambda x: x.index.isin(shell_gene_list)]
    x = d[d.is_cog_category].phylogenetic_signal
    y = d[~d.is_cog_category].phylogenetic_signal
    # The test needs at least one gene on each side; otherwise record NaNs.
    if (len(x) > 0) and (len(y) > 0):
        _test = sp.stats.mannwhitneyu(x, y)
    else:
        _test = (np.nan, np.nan)
    cog_category_phylogenetic_signal_test[_cog_category] = (
        len(x),
        x.median(),
        y.median(),
        x.mean(),
        y.mean(),
        *_test,
    )

# One row per COG category, plus derived columns used for plotting.
cog_category_phylogenetic_signal_test = pd.DataFrame(
    cog_category_phylogenetic_signal_test,
    index=[
        "num_genes_in_category",
        "cog_median_r",
        "not_cog_median_r",
        "cog_mean_r",
        "not_cog_mean_r",
        "mwu_statistic",
        "pvalue",
    ],
).T.assign(
    negative_log10_pvalue=lambda x: -np.log10(x.pvalue),
    median_diff=lambda x: x.cog_median_r - x.not_cog_median_r,
)

In [None]:
# Volcano-style plot of the per-category Mann-Whitney results: difference in
# median r (in-category minus others) vs. -log10 p-value.
d = cog_category_phylogenetic_signal_test.join(cog_category_description)

fig, ax = plt.subplots()
# Missing values (e.g. categories where the test could not be run) are plotted at 0.
plt.scatter(
    "median_diff",
    "negative_log10_pvalue",
    data=d[["description", "median_diff", "negative_log10_pvalue"]]
    .fillna(0)
    .sort_values("median_diff", ascending=False),
)
# lib.plot.rotate_xticklabels()
# Reference lines: no difference (x = 0) and the 0.05 threshold divided by the
# number of categories (Bonferroni-style).
ax.axvline(0, linestyle="--", lw=1, color="k")
ax.axhline(-np.log10(0.05 / d.shape[0]), linestyle="--", lw=1, color="k")


# Label categories with a non-zero median difference and p < 1e-4.
for cog_category in idxwhere((np.abs(d.median_diff) > 0) & (d.pvalue < 1e-4)):
    ax.annotate(
        cog_category,
        xy=d[["median_diff", "negative_log10_pvalue"]].loc[cog_category],
        xytext=(2, 1),
        textcoords="offset pixels",
        # color=prevalence_class_palette[prevalence_class],
        # ha='center',
        # va='center',
        fontweight="bold",
    )

ax.set_ylabel("-log10(p-value)")
ax.set_xlabel("Difference in Median $r$\n(Genes in COG category vs. Others)")
# Fixed axis limits; points outside them are not shown.
ax.set_xlim(-0.35, 0.35)
ax.set_ylim(0, 20)

# Last expression: the full table, sorted, as the cell's output.
d.sort_values("median_diff", ascending=False)

In [None]:
# Per-gene r statistic joined with the gene-to-COG-category table; missing
# values are replaced by "-".
# NOTE(review): fillna("-") applies to every column, so a NaN r_statistic would
# also become the string "-" -- confirm r_statistic has no NaNs here.
d = phylogenetic_signal.r_statistic.to_frame().join(gene_x_cog_category).fillna("-")

fig, ax = plt.subplots(figsize=(20, 5))
# One violin per COG category, ordered by the category's median r among shell
# genes (from the Mann-Whitney summary table).
sns.violinplot(
    data=d,
    x="cog_category",
    y="r_statistic",
    ax=ax,
    order=cog_category_phylogenetic_signal_test.cog_median_r.sort_values().index,
)

## Donor Comparison

In [None]:
# Filtered gene-content distance matrix with exact zeros masked as NaN.
x = strain_uhgg_filt_pdist.replace({0: np.nan})

# Annotations: genome type and donor strain, using each donor strain's own
# palette color; every other strain is grey.
_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.index.to_series()
        .map(
            {d44_strain: spgc_palette[d44_strain], d97_strain: spgc_palette[d97_strain]}
        )
        .fillna("grey"),
    )
)

# Rows follow the gene-content linkage, columns the genotype linkage.
g = sns.clustermap(
    x,
    row_colors=_colors,
    col_colors=_colors,
    row_linkage=strain_uhgg_filt_linkage,
    col_linkage=strain_geno_linkage,
    # figsize=(40, 40),
)
# Background color for the heatmap axes, so the NaN-masked cells stand out.
g.ax_heatmap.set_facecolor("aqua")

In [None]:
# Display the two donor strain identifiers used throughout this section.
d44_strain, d97_strain

In [None]:
# 2x2 count of genes by presence (non-zero value) in each of the two donor strains.
strain_uhgg_filt[[d44_strain, d97_strain]].astype(bool).value_counts().unstack()

In [None]:
# Label each filtered gene by which donor strain(s) carry it: the pair of
# booleans (present in d44 strain, present in d97 strain) is mapped to a name.
donor_strain_specificity = (
    strain_uhgg_filt[[d44_strain, d97_strain]]
    .astype(bool)
    .apply(tuple, axis=1)
    .map(
        {
            (True, True): "both",
            (True, False): "d44",
            (False, True): "d97",
            (False, False): "neither",
        }
    )
)

In [None]:
# Gene lists for each donor-specificity label.
d44_only_genes = idxwhere(donor_strain_specificity == "d44")
d97_only_genes = idxwhere(donor_strain_specificity == "d97")
both_donor_strain_genes = idxwhere(donor_strain_specificity == "both")

# Display the three list sizes.
len(d44_only_genes), len(d97_only_genes), len(both_donor_strain_genes)

In [None]:
# Fisher's exact test for association between donor-strain specificity and
# each COG category, restricted to genes present in at least one donor strain.
x = donor_strain_specificity[donor_strain_specificity != "neither"]
y = gene_x_cog_category_matrix
# NOTE(review): `gene_list` is not used in this cell.
gene_list = x.index

cog_category_donor_strain_enrichment_test = []

for _donor_strain_specificity, _cog_category in product(x.unique(), cog_category_order):
    # 2x2 table: rows = gene has this specificity label, columns = gene is in
    # this COG category. The two Series are aligned on their index when the
    # frame is built; reindex/fillna guarantees all four cells exist.
    contingency_table = (
        pd.DataFrame(
            dict(
                is_donor_strain=(x == _donor_strain_specificity),
                is_cog_category=y[_cog_category],
            )
        )
        .value_counts()
        .unstack()
        .reindex(index=[False, True], columns=[False, True])
        .fillna(0)
    )
    _test = sp.stats.fisher_exact(contingency_table)
    # Record: label, category, odds ratio, p-value, and the number of genes
    # with both the label and the category.
    cog_category_donor_strain_enrichment_test.append(
        (
            _donor_strain_specificity,
            _cog_category,
            _test[0],
            _test[1],
            contingency_table.loc[True, True],
        )
    )

# Collect the records into a frame indexed by (label, category) and add
# log-transformed columns for plotting.
cog_category_donor_strain_enrichment_test = (
    pd.DataFrame(
        cog_category_donor_strain_enrichment_test,
        columns=["donor_strain", "cog_category", "statistic", "pvalue", "gene_count"],
    )
    .set_index(["donor_strain", "cog_category"])
    .assign(
        negative_log10_pvalue=lambda x: -np.log10(x.pvalue),
        log2_odds_ratio=lambda x: np.log2(x.statistic),
    )
)

In [None]:
# Column order for the heatmap.
donor_strain_order = ["d44", "both", "d97"]
# Heatmap values: log2 odds ratio per (COG category, specificity label).
# Infinite values are replaced by NaN and then, with any other missing values,
# by 0; rows are relabelled with the human-readable category description.
x = (
    cog_category_donor_strain_enrichment_test.log2_odds_ratio.unstack("donor_strain")
    .replace({np.inf: np.nan, -np.inf: np.nan})
    .join(cog_category_description)
    .set_index("description")[donor_strain_order]
    .fillna(0)
)
# annot = (cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker) + '|' + cog_category_gene_class_enrichment_test.gene_count.astype(int).astype(str)).unstack('prevalence_class')[prevalence_class_order]
# Cell annotations: significance stars derived from the (uncorrected) p-values.
annot = (
    cog_category_donor_strain_enrichment_test.pvalue.map(_assign_significance_marker)
    .unstack("donor_strain")
    .join(cog_category_description)
    .set_index("description")[donor_strain_order]
)
# annot = cog_category_gene_class_enrichment_test.gene_count.unstack('prevalence_class')[prevalence_class_order].astype(int)

# Rows sorted from most to least enriched among genes shared by both donors.
_row_order = x["both"].sort_values(ascending=False).index
# x, annot = lib.pandas_util.align_indexes(x, annot)

fig, ax = plt.subplots(figsize=(5, 12))
# Diverging color map centered at 0 and clipped to [-5, 5]; fmt="" lets the
# string annotations be drawn as-is.
ax = sns.heatmap(
    x.reindex(_row_order),
    annot=annot.reindex(_row_order),
    fmt="",
    cmap="coolwarm",
    center=0,
    vmin=-5,
    vmax=5,
    cbar_kws=dict(
        use_gridspec=True, location="left", label="log2(odds ratio)", extend="both"
    ),
    ax=ax,
    yticklabels=1,
    xticklabels=1,
    annot_kws=dict(va="center"),
    # norm=mpl.colors.SymLogNorm(linthresh=1e1),
    # center=0,
)

# Category labels go on the right because the colorbar occupies the left.
ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=-0, va="center")

In [None]:
# Same heatmap as the previous cell, but with a fixed, hard-coded row order
# instead of sorting by the "both" column.
# NOTE(review): this cell duplicates the previous one apart from `_row_order`;
# a shared plotting helper taking the row order would remove the copy.
donor_strain_order = ["d44", "both", "d97"]
x = (
    cog_category_donor_strain_enrichment_test.log2_odds_ratio.unstack("donor_strain")
    .replace({np.inf: np.nan, -np.inf: np.nan})
    .join(cog_category_description)
    .set_index("description")[donor_strain_order]
    .fillna(0)
)
# annot = (cog_category_gene_class_enrichment_test.pvalue.map(_assign_significance_marker) + '|' + cog_category_gene_class_enrichment_test.gene_count.astype(int).astype(str)).unstack('prevalence_class')[prevalence_class_order]
annot = (
    cog_category_donor_strain_enrichment_test.pvalue.map(_assign_significance_marker)
    .unstack("donor_strain")
    .join(cog_category_description)
    .set_index("description")[donor_strain_order]
)
# annot = cog_category_gene_class_enrichment_test.gene_count.unstack('prevalence_class')[prevalence_class_order].astype(int)

# Fixed row order, given as full "<code>: <description>" labels. Any label
# here that is absent from `x` yields an all-NaN row after reindexing.
# NOTE(review): where this particular ordering comes from is not shown in
# this cell -- document its source.
_row_order = [
    "F: Nucleotide transport and metabolism",
    "A: RNA processing and modification",
    "J: Translation, ribosomal structure and biogenesis",
    "C: Energy production and conversion",
    "E: Amino acid transport and metabolism",
    "P: Inorganic ion transport and metabolism",
    "Z: Cytoskeleton",
    "H: Coenzyme transport and metabolism",
    "O: Post-translational modification, protein turnover, and chaperones",
    "I: Lipid transport and metabolism",
    "T: Signal transduction mechanisms",
    "Q: Secondary metabolites biosynthesis, transport, and catabolism",
    "G: Carbohydrate transport and metabolism",
    "R: General function prediction only",
    "K: Transcription",
    "M: Cell wall/membrane/envelope biogenesis",
    "B: Chromatin structure and dynamics",
    "D: Cell cycle control, cell division, chromosome partitioning",
    "V: Defense mechanisms",
    "S: Function unknown",
    "N: Cell motility",
    "L: Replication, recombination and repair",
    "U: Intracellular trafficking, secretion, and vesicular transport",
    "W: Extracellular structures",
    "-: No Annotation",
    "X: Mobilome: prophages, transposons",
]
# x, annot = lib.pandas_util.align_indexes(x, annot)

fig, ax = plt.subplots(figsize=(5, 12))
# Diverging color map centered at 0 and clipped to [-5, 5]; fmt="" lets the
# string annotations be drawn as-is.
ax = sns.heatmap(
    x.reindex(_row_order),
    annot=annot.reindex(_row_order),
    fmt="",
    cmap="coolwarm",
    center=0,
    vmin=-5,
    vmax=5,
    cbar_kws=dict(
        use_gridspec=True, location="left", label="log2(odds ratio)", extend="both"
    ),
    ax=ax,
    yticklabels=1,
    xticklabels=1,
    annot_kws=dict(va="center"),
    # norm=mpl.colors.SymLogNorm(linthresh=1e1),
    # center=0,
)

# Category labels go on the right because the colorbar occupies the left.
ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=-0, va="center")

In [None]:
# Same enrichment heat map as before, but each cell is annotated with the summed
# gene x COG-category membership for that donor-strain class instead of a
# significance marker.
donor_strain_order = ["d44", "both", "d97"]

# Colour values: log2 odds ratios, with infinite values blanked and then shown as 0.
_log2_or_wide = cog_category_donor_strain_enrichment_test.log2_odds_ratio.unstack(
    "donor_strain"
)
x = (
    _log2_or_wide.replace({np.inf: np.nan, -np.inf: np.nan})
    .join(cog_category_description)
    .set_index("description")[donor_strain_order]
    .fillna(0)
)

# Annotation values: per-class column sums of the gene x COG-category matrix.
_membership_by_class = (
    donor_strain_specificity.to_frame(name="donor_strain")
    .join(gene_x_cog_category_matrix)
    .groupby("donor_strain")
    .sum()
)
annot = (
    _membership_by_class.T.join(cog_category_description)
    .set_index("description")[donor_strain_order]
    .fillna(0)
)

# Row order is listed explicitly rather than derived from the data.
_row_order = [
    "F: Nucleotide transport and metabolism",
    "A: RNA processing and modification",
    "J: Translation, ribosomal structure and biogenesis",
    "C: Energy production and conversion",
    "E: Amino acid transport and metabolism",
    "P: Inorganic ion transport and metabolism",
    "Z: Cytoskeleton",
    "H: Coenzyme transport and metabolism",
    "O: Post-translational modification, protein turnover, and chaperones",
    "I: Lipid transport and metabolism",
    "T: Signal transduction mechanisms",
    "Q: Secondary metabolites biosynthesis, transport, and catabolism",
    "G: Carbohydrate transport and metabolism",
    "R: General function prediction only",
    "K: Transcription",
    "M: Cell wall/membrane/envelope biogenesis",
    "B: Chromatin structure and dynamics",
    "D: Cell cycle control, cell division, chromosome partitioning",
    "V: Defense mechanisms",
    "S: Function unknown",
    "N: Cell motility",
    "L: Replication, recombination and repair",
    "U: Intracellular trafficking, secretion, and vesicular transport",
    "W: Extracellular structures",
    "-: No Annotation",
    "X: Mobilome: prophages, transposons",
]

# Diverging colour scale clipped to [-5, 5]; smaller annotation font than the
# marker-annotated version of this figure.
_heatmap_kws = dict(
    fmt="",
    cmap="coolwarm",
    center=0,
    vmin=-5,
    vmax=5,
    cbar_kws=dict(
        use_gridspec=True, location="left", label="log2(odds ratio)", extend="both"
    ),
    yticklabels=1,
    xticklabels=1,
    annot_kws=dict(fontsize=12, va="center"),
)

fig, ax = plt.subplots(figsize=(5, 12))
ax = sns.heatmap(
    x.reindex(_row_order), annot=annot.reindex(_row_order), ax=ax, **_heatmap_kws
)

# The colour bar sits on the left, so the category labels go on the right.
ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=0, va="center")

In [None]:
# Heat map of how many genes fall in each COG category, split by donor-strain
# class. Relies on `donor_strain_order` and `_row_order` from the preceding
# enrichment heat-map cell.
d = (
    donor_strain_specificity.to_frame(name="donor_strain")
    .join(gene_x_cog_category_matrix)
    .groupby("donor_strain")
    .sum()
    .T.join(cog_category_description)
    .set_index("description")[donor_strain_order]
    .fillna(0)
)

fig, ax = plt.subplots(figsize=(5, 12))
ax = sns.heatmap(
    d.reindex(_row_order),
    cmap="Oranges",
    # `d` holds summed category memberships, not odds ratios, so the colour bar is
    # labelled accordingly (it previously carried the label copied from the
    # enrichment figure).
    cbar_kws=dict(use_gridspec=True, location="left", label="gene count"),
    ax=ax,
    yticklabels=1,
    xticklabels=1,
    # PowerNorm with gamma=1 is a linear scale; kept as the single knob to turn if
    # a compressed scale is wanted. `fmt`/`annot_kws` were dropped because no
    # `annot` is passed, so they had no effect.
    norm=mpl.colors.PowerNorm(1),
)

# The colour bar sits on the left, so the category labels go on the right.
ax.yaxis.set_ticks_position("right")
ax.set_ylabel("")
lib.plot.rotate_yticklabels(rotation=-0, va="center")

In [None]:
# The (up to) 50 rows of the "both" donor-strain class with the highest
# log2 odds ratio, with category descriptions joined on.
(
    cog_category_donor_strain_enrichment_test.xs("both", level="donor_strain")
    .sort_values("log2_odds_ratio", ascending=False)
    .head(50)
    .join(cog_category_description)
)

## Gene Co-occurrence

### Found using SPGC and Ref

In [None]:
# Row-wise mean of `strain_uhgg_filt` over the SPGC and reference genome columns,
# used below as each gene's prevalence.
_all_genomes = spgc_list + list(ref_list)
gene_prevalence = strain_uhgg_filt[_all_genomes].mean(1)

plt.hist(gene_prevalence)
plt.yscale("log")

# "Variable" genes: prevalence strictly between 5% and 95%.
_is_variable = (gene_prevalence > 0.05) & (gene_prevalence < 0.95)
variable_genes = idxwhere(_is_variable)
len(variable_genes)

In [None]:
# Cluster the variable genes by their profile across SPGC + reference genomes,
# then show all of them in one clustermap with cluster membership as row colours.
x = strain_uhgg_filt.loc[variable_genes]
y = x[spgc_list + list(ref_list)]  # columns actually used for clustering
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)

# Average-linkage clustering on correlation distance; the tree is cut at a fixed
# distance threshold rather than at a fixed number of clusters.
clust2 = pd.Series(
    AgglomerativeClustering(
        n_clusters=None,
        metric="correlation",
        linkage="average",
        distance_threshold=0.035,
    ).fit_predict(y),
    index=x.index,
)
# Shuffle the cluster labels before assigning colours. The shuffle is seeded so
# that the palette (and every figure coloured by it) is identical on each run.
clust2_palette = lib.plot.construct_ordered_palette(
    clust2.value_counts().sample(frac=1.0, random_state=0).index, cm="rainbow"
)

sns.clustermap(
    x,
    row_colors=x.index.to_series().map(clust2).map(clust2_palette),
    col_linkage=strain_geno_linkage,
    col_colors=_col_colors,
)

In [None]:
# Per-cluster summary for `clust2`: gene counts by donor-strain class, plus the five
# COG categories with the largest summed membership.
clust2_sizes = clust2.value_counts()

# Rows are clusters, columns are donor-strain classes, values are gene counts.
_donor_counts_by_clust = (
    clust2.to_frame("clust")
    .assign(donor=donor_strain_specificity)
    .groupby("clust")
    .donor.value_counts()
    .unstack(fill_value=0)
)

# For every cluster, the five highest-ranked COG category labels.
_top5_cog_by_clust = (
    gene_x_cog_category_matrix.groupby(clust2)
    .sum()
    .rename(int)
    .astype(int)
    .apply(
        lambda row: row.sort_values(ascending=False).head(5).index.to_list(), axis=1
    )
)

# Rank clusters by how many d44- or d97-specific genes they hold.
clust2_sizes_meta = (
    _donor_counts_by_clust.assign(sum_of_d44_d97=lambda df: df.d44 + df.d97)
    .sort_values("sum_of_d44_d97", ascending=False)
    .assign(top5_cc=_top5_cog_by_clust)
)

clust2_sizes_meta.head(10)

In [None]:
# Drill into the cluster at position 0 of `clust2_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust2_sizes_meta.index[0]
_gene_list = idxwhere(clust2 == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 1 of `clust2_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust2_sizes_meta.index[1]
_gene_list = idxwhere(clust2 == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 2 of `clust2_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust2_sizes_meta.index[2]
_gene_list = idxwhere(clust2 == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 3 of `clust2_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust2_sizes_meta.index[3]
_gene_list = idxwhere(clust2 == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 4 of `clust2_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust2_sizes_meta.index[4]
_gene_list = idxwhere(clust2 == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 5 of `clust2_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust2_sizes_meta.index[5]

_gene_list = idxwhere(clust2 == _clust)
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        # Column colours must be keyed by the *columns* of `x` (the strains).
        # This cell used `x.index` (the genes), which never matches either donor
        # strain ID and left the donor-strain colour strip without highlights.
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)
# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_row_colors = (
    gene_annotations.reindex(_gene_list)
    .centroid_99_length.map(np.log10)
    .map(lambda x: x / 4)
    .map(mpl.cm.viridis)
)
_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# First heat map: `x` on the default colour scale.
sns.clustermap(
    x,
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
# Second heat map: `y` with a square-root colour norm clipped to [0, 2], reusing
# the row ordering computed from `x`.
sns.clustermap(
    y,
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2),
    xticklabels=0,
    yticklabels=0,
)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
print(
    gene_x_cog_category_matrix.reindex(_gene_list)
    .sum()
    .sort_values(ascending=False)[lambda x: x > 0]
    .to_frame("tally")
    .join(cog_category_description)
)
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 6 of `clust2_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust2_sizes_meta.index[6]
_gene_list = idxwhere(clust2 == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

#### Detailed Donor Comparison of Clusters

In [None]:
# Display the COG category description lookup that the tallies in this notebook are joined against.
cog_category_description

In [None]:
# Display the per-gene donor-strain class labels (the code in this section maps "d44", "d97", "both" and "neither").
donor_strain_specificity

In [None]:
# Pool the 10 clusters holding the most d44-specific genes with the 10 holding the
# most d97-specific genes. A cluster can appear in both rankings; the `isin`
# lookup below is unaffected by such duplicates.
_clust_list = []
_clust_list += (
    clust2_sizes_meta.d44.sort_values(ascending=False).head(10).index.to_list()
)
_clust_list += (
    clust2_sizes_meta.d97.sort_values(ascending=False).head(10).index.to_list()
)

_gene_list = idxwhere(clust2.isin(_clust_list))
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]
# Column annotations: genome type plus a black/white flag for the two donor strains.
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=x.columns.to_series()
        .map({d44_strain: "black", d97_strain: "white"})
        .fillna("grey"),
    )
)
# Row annotations, one colour strip per column of this frame.
_bool_to_color = {True: "black", False: "white"}
_row_colors = pd.DataFrame(
    dict(
        # Min-max scale the r statistic into [0, 1] before colour-mapping. The
        # numerator needs its own parentheses: without them only `x.min()` was
        # divided by the range, so the values were merely shifted, not scaled.
        phylo=phylogenetic_signal.r_statistic.pipe(
            lambda r: (r - r.min()) / (r.max() - r.min())
        ).map(mpl.cm.viridis),
        x=gene_x_cog_category_matrix["X"].map(_bool_to_color),  # Mobilome
        m=gene_x_cog_category_matrix["M"].map(_bool_to_color),  # Cell wall
        n=gene_x_cog_category_matrix["N"].map(_bool_to_color),  # Motility
        w=gene_x_cog_category_matrix["W"].map(
            _bool_to_color
        ),  # Extracellular struct.
        u=gene_x_cog_category_matrix["U"].map(
            _bool_to_color
        ),  # Trafficking and secretion
        clust=clust2.map(clust2_palette),
        donor_strain=donor_strain_specificity.map(
            {
                "d44": "tab:blue",
                "d97": "tab:green",
                "both": "tab:purple",
                "neither": "grey",
            }
        ),
    )
)
_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

sns.clustermap(
    x,
    figsize=(10, 10),
    row_linkage=_row_linkage,
    row_colors=_row_colors,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
# sns.clustermap(y, figsize=(10, 5), row_linkage=_row_linkage, col_linkage=_col_linkage, col_colors=_col_colors, norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=2), xticklabels=0, yticklabels=0)

# Per-COG-category tally over the pooled genes (non-zero categories only).
print(
    gene_x_cog_category_matrix.reindex(_gene_list)
    .sum()
    .sort_values(ascending=False)[lambda x: x > 0]
    .to_frame("tally")
    .join(cog_category_description)
)
# gene_annotations.reindex(_gene_list)

In [None]:
# Display the per-gene donor-strain class labels again (same object as shown earlier in this section).
donor_strain_specificity

### Found using SPGC Only

In [None]:
# Row-wise mean of `strain_uhgg_filt` over the SPGC columns only, used below as
# each gene's prevalence.
spgc_gene_prevalence = strain_uhgg_filt[spgc_list].mean(1)

plt.hist(spgc_gene_prevalence)
plt.yscale("log")

# "Variable" genes here: prevalence strictly between 10% and 90%.
_is_spgc_variable = (spgc_gene_prevalence > 0.1) & (spgc_gene_prevalence < 0.9)
spgc_variable_genes = idxwhere(_is_spgc_variable)
len(spgc_variable_genes)

In [None]:
# Cluster the SPGC-variable genes by their profile across the SPGC columns, then
# show all of them in one clustermap with cluster membership as row colours.
x = strain_uhgg_filt.loc[spgc_variable_genes]
y = x[spgc_list]  # columns actually used for clustering

_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# Average-linkage clustering on correlation distance; the tree is cut at a fixed
# distance threshold rather than at a fixed number of clusters.
_clusterer = AgglomerativeClustering(
    n_clusters=None,
    metric="correlation",
    linkage="average",
    distance_threshold=0.035,
)
clust = pd.Series(_clusterer.fit_predict(y), index=x.index)
clust_palette = lib.plot.construct_ordered_palette(clust.unique())

_cluster_row_colors = x.index.to_series().map(clust).map(clust_palette)
sns.clustermap(
    x,
    row_colors=_cluster_row_colors,
    col_linkage=strain_geno_linkage,
    col_colors=_col_colors,
)

In [None]:
# Per-cluster summary for `clust`: gene counts by donor-strain class, plus the five
# COG categories with the largest summed membership.
clust_sizes = clust.value_counts()

# Rows are clusters, columns are donor-strain classes, values are gene counts.
_donor_counts_by_clust = (
    clust.to_frame("clust")
    .assign(donor=donor_strain_specificity)
    .groupby("clust")
    .donor.value_counts()
    .unstack(fill_value=0)
)

# For every cluster, the five highest-ranked COG category labels.
_top5_cog_by_clust = (
    gene_x_cog_category_matrix.groupby(clust)
    .sum()
    .rename(int)
    .astype(int)
    .apply(
        lambda row: row.sort_values(ascending=False).head(5).index.to_list(), axis=1
    )
)

# Rank clusters by how many d44- or d97-specific genes they hold.
clust_sizes_meta = (
    _donor_counts_by_clust.assign(sum_of_d44_d97=lambda df: df.d44 + df.d97)
    .sort_values("sum_of_d44_d97", ascending=False)
    .assign(top5_cc=_top5_cog_by_clust)
)

clust_sizes_meta.head(10)

In [None]:
# Drill into the cluster at position 0 of `clust_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust_sizes_meta.index[0]
_gene_list = idxwhere(clust == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 1 of `clust_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust_sizes_meta.index[1]
_gene_list = idxwhere(clust == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 2 of `clust_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust_sizes_meta.index[2]
_gene_list = idxwhere(clust == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 3 of `clust_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust_sizes_meta.index[3]
_gene_list = idxwhere(clust == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 4 of `clust_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust_sizes_meta.index[4]
_gene_list = idxwhere(clust == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 5 of `clust_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust_sizes_meta.index[5]
_gene_list = idxwhere(clust == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)

In [None]:
# Drill into the cluster at position 6 of `clust_sizes_meta`, which is sorted by
# descending d44 + d97 gene count.
_clust = clust_sizes_meta.index[6]
_gene_list = idxwhere(clust == _clust)

# Same genes, two tables: `strain_gene_uhgg` and its `_depth` counterpart.
x = strain_gene_uhgg.loc[_gene_list]
y = strain_gene_uhgg_depth.loc[_gene_list]

# Column annotations: genome type plus a black/white flag for the two donor strains.
_donor_flag = x.columns.to_series().map({d44_strain: "black", d97_strain: "white"})
_col_colors = pd.DataFrame(
    dict(
        genome_type=genome_type.map(genome_type_palette),
        donor_strain=_donor_flag.fillna("grey"),
    )
)

# log10(`centroid_99_length`) / 4 pushed through viridis; computed here but not
# handed to either plot.
_log10_length = gene_annotations.reindex(_gene_list).centroid_99_length.map(np.log10)
_row_colors = _log10_length.map(lambda v: v / 4).map(mpl.cm.viridis)

_col_linkage = strain_geno_linkage
_row_linkage = sp.cluster.hierarchy.linkage(
    x, method="average", metric="cosine", optimal_ordering=True
)  # TODO

# Both heat maps share the layout; only the second one gets a square-root colour norm.
_shared_kws = dict(
    figsize=(10, 5),
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    col_colors=_col_colors,
    xticklabels=0,
    yticklabels=0,
)
sns.clustermap(x, **_shared_kws)
sns.clustermap(y, norm=mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=2), **_shared_kws)

# Per-COG-category tally for this cluster (non-zero only), then the gene annotations.
_cog_tally = (
    gene_x_cog_category_matrix.reindex(_gene_list).sum().sort_values(ascending=False)
)
print(_cog_tally[lambda s: s > 0].to_frame("tally").join(cog_category_description))
gene_annotations.reindex(_gene_list)