## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
# Run the notebook from the project root so the relative paths used below resolve.
# PROJECT_ROOT must be set in the environment; a missing variable raises KeyError.
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
# Show the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
# Global plotting style: seaborn "paper" context and high-resolution figures.
sns.set_context("paper")
plt.rcParams["figure.dpi"] = 300
# NOTE: The Chrome page inspector can be used to scale these large, presentation-quality PNGs so that they fit inside the notebook.

# FIXME: Figures should use Helvetica or Arial, but those fonts do not appear to be installed anywhere matplotlib looks.
# Only the uncommented entries below are active in the sans-serif fallback list.
plt.rcParams["font.sans-serif"] = [
    # "Helvetica",
    "DejaVu Sans",
    # "Bitstream Vera Sans",
    # "Computer Modern Sans Serif",
    # "Lucida Grande",
    # "Verdana",
    # "Geneva",
    # "Lucid",
    # "Arial",
    # "Avant Garde",
    # "sans-serif",
]

## Metadata

In [None]:
# Unique species IDs (as strings) assigned to the "hmp2" species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda groups: groups.species_group_id == "hmp2", "species_id"]
    .astype(str)
    .unique()
)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ";"-delimited lineage string into a Series keyed by rank prefix.

    The fields are assigned, in order, to the labels
    d__, p__, c__, o__, f__, g__, s__. pandas raises ValueError when the string
    does not contain exactly seven fields.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    return pd.Series(taxonomy_string.split(";"), index=rank_prefixes)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# One row per species representative genome (Genome == Species_rep), indexed by
# a species_id built as "1" + the third "-"-separated field of MGnify_accession,
# with the Lineage string expanded into one column per taxonomic rank.
species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)
    .loc[lambda genomes: genomes.Genome == genomes.Species_rep]
    .assign(
        species_id=lambda genomes: "1"
        + genomes.MGnify_accession.str.split("-").str[2]
    )
    .set_index("species_id")
    .Lineage.apply(parse_taxonomy_string)
)
species_taxonomy

In [None]:
# Fixed display order of phyla; the palette below is built from this order.
# Commented-out entries are phyla deliberately left out (see the inline notes).
phylum_order = [
    "p__Euryarchaeota",
    "p__Thermoplasmatota",
    "p__Firmicutes",
    "p__Firmicutes_A",
    "p__Firmicutes_C",
    # "p__Firmicutes_B", # None in species_list1
    # "p__Firmicutes_G", # B/G/I not sure how related to C or A
    # "p__Firmicutes_I", #
    # "p__Cyanobacteria", # None in species_list1
    "p__Actinobacteriota",
    "p__Synergistota",
    "p__Fusobacteriota",
    "p__Campylobacterota",
    "p__Proteobacteria",
    "p__Desulfobacterota_A",
    "p__Bacteroidota",
    "p__Verrucomicrobiota",
    # "dummy0", # 18
    # "dummy1", # 19
    # "dummy2", # 20
]

# Mapping from phylum name to colour, indexed below by phylum.
# NOTE(review): presumably cycles through the two desaturation levels along the
# "rainbow" colormap — confirm against lib.plot.construct_ordered_palette.
phylum_palette = lib.plot.construct_ordered_palette(
    phylum_order,
    cm="rainbow",
    desaturate_levels=[1.0, 0.5],
)

# Print each colour and draw a legend-only figure: the empty scatters exist just
# to create legend handles.
for p__ in phylum_order:
    print(p__, phylum_palette[p__])
    plt.scatter([], [], color=phylum_palette[p__], label=p__)
plt.legend(ncols=4)
lib.plot.hide_axes_and_spines()

# Disabled check that every phylum received a distinct colour:
# assert len(set(phylum_palette.values())) == len((phylum_palette.values()))

In [None]:
def classify_genome(x):
    """Return the genome-class label for one row of strain metadata.

    Parameters
    ----------
    x : row-like object (e.g. a ``pd.Series`` from ``DataFrame.apply(axis=1)``)
        Must expose ``genome_type`` and ``passes_filter``;
        ``passes_geno_positions`` is read only for SPGC genomes that fail
        ``passes_filter``.

    Returns
    -------
    str
        "isolate" / "isolate_fails_qc" for Isolate genomes and
        "mag" / "mag_fails_qc" for MAG genomes, depending on ``passes_filter``.
        For SPGC genomes: "spgc" if ``passes_filter``, otherwise "sfacts_only"
        if ``passes_geno_positions``, otherwise "sfacts_fails_qc".

    Raises
    ------
    ValueError
        If ``genome_type`` is not one of "Isolate", "MAG" or "SPGC".
    """
    # Use boolean logic (truthiness / ``not``) rather than bitwise ``&`` / ``~``:
    # rows produced by ``apply(axis=1)`` on a mixed-dtype frame hold plain Python
    # bools, and ``~bool`` is integer inversion (deprecated since Python 3.12).
    genome_type = x.genome_type
    if genome_type == "Isolate":
        return "isolate" if x.passes_filter else "isolate_fails_qc"
    if genome_type == "MAG":
        return "mag" if x.passes_filter else "mag_fails_qc"
    if genome_type == "SPGC":
        if x.passes_filter:
            return "spgc"
        return "sfacts_only" if x.passes_geno_positions else "sfacts_fails_qc"
    raise ValueError("Genome did not match classification criteria:", x)

In [None]:
# Load the per-species strain metadata tables and combine them into one frame.
# `filt_stats` starts as a list of per-species frames and is rebound below to the
# concatenated DataFrame; species without a table are recorded in `missing_species`.
filt_stats = []
missing_species = []

_species_list = species_list
# _species_list = ["100003"]

for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv"
    if not os.path.exists(inpath):
        # Skip species whose table is absent instead of failing the whole load.
        missing_species.append(species)
        continue
    data = pd.read_table(inpath).assign(species=species, inpath=inpath)
    filt_stats.append(data)
# Classify every genome row and index the frame by "<species>_<genome_id>".
# NOTE: pd.concat raises ValueError if no table was found for any species.
filt_stats = (
    pd.concat(filt_stats)
    .assign(
        genome_class=lambda x: x.apply(classify_genome, axis=1),
        species_strain=lambda x: x.species + "_" + x.genome_id,
    )
    .set_index("species_strain")
)


print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

In [None]:
# Define progressively stricter subsets of species, each paired with the SPGC
# strains (index labels of `filt_stats`) that belong to it.

# Row masks shared by all of the subsets below.
_is_spgc = filt_stats.genome_type.isin(["SPGC"])
_sfacts_mask = filt_stats.passes_geno_positions & _is_spgc
_spgc_mask = filt_stats.passes_filter & _is_spgc


def _strain_ids(mask, species):
    """Index labels of the `filt_stats` rows in `mask` whose species is in `species`."""
    return filt_stats[mask & filt_stats.species.isin(species)].index.values


# Number of qualifying strains per species, used for the minimum-count subsets.
_sfacts_strain_counts = filt_stats[_sfacts_mask].species.value_counts()
_spgc_strain_counts = filt_stats[_spgc_mask].species.value_counts()

# All species:
# species_list
spgc_strain_list = filt_stats[_is_spgc].index.values

# All species with enough positions
species_list0 = filt_stats[filt_stats.passes_geno_positions].species.unique()
spgc_strain_list0 = _strain_ids(_sfacts_mask, species_list0)

# All species with sf strains
species_list1 = filt_stats[_sfacts_mask].species.unique()
spgc_strain_list1 = _strain_ids(_sfacts_mask, species_list1)

# All species with sf strains to talk about distributions (>=10)
species_list1b = idxwhere(_sfacts_strain_counts >= 10)
spgc_strain_list1b = _strain_ids(_sfacts_mask, species_list1b)

# All species with spgc strains
species_list2 = filt_stats[_spgc_mask].species.unique()
spgc_strain_list2 = _strain_ids(_spgc_mask, species_list2)

# All species with enough spgc strains for pangenome analysis (>=10)
species_list3 = idxwhere(_spgc_strain_counts >= 10)
spgc_strain_list3 = _strain_ids(_spgc_mask, species_list3)

# Species with large numbers of strains (>=20)
species_list4 = idxwhere(_spgc_strain_counts >= 20)
spgc_strain_list4 = _strain_ids(_spgc_mask, species_list4)

# Report the size and phylum composition of every subset.
_species_list_map = {
    "All considered species": (species_list, spgc_strain_list),
    "0: Species with enough genotyped positions": (species_list0, spgc_strain_list0),
    "1: With sfacts strains": (species_list1, spgc_strain_list1),
    "1b: With (>=10) sfacts strains": (species_list1b, spgc_strain_list1b),
    "2: With SPGC inferences": (species_list2, spgc_strain_list2),
    "3: With >=10 inferences": (species_list3, spgc_strain_list3),
    "4: With >=20 inferences": (species_list4, spgc_strain_list4),
}
for _species_list_name, (_species_list, _strain_list) in _species_list_map.items():
    print(_species_list_name, len(_species_list), len(_strain_list))
    print(species_taxonomy.loc[_species_list].p__.value_counts())
    print()

## Analysis

## Higher rate of same-categories and same-module (plasmids?) matching in clusters

In [None]:
# Gather clusters for all species (among a particular species list, that is)
# Gather annotations for all genes in these clusters
# Put everything together and do the same analysis as above

In [None]:
# Load, for every species in species_list4, its gene-cluster assignments and the
# gene annotation tables (KEGG module, COG category, geNomad plasmid / virus).
# Each name starts as a list of per-species frames and is rebound to the
# concatenated DataFrame after the loop.
clust = []
kegg_module = []
cog_category = []
plasmid = []
phage = []

for species in tqdm(species_list4):
    clust_inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.gene_clust-t10.tsv"
    kegg_module_inpath = (
        f"data/species/sp-{species}/midasdb_v15.emapper.gene75_x_kegg_module.tsv"
    )
    cog_category_inpath = (
        f"data/species/sp-{species}/midasdb_v15.emapper.gene75_x_cog_category.tsv"
    )
    plasmid_inpath = (
        f"data/species/sp-{species}/midasdb_v15.gene75_x_genomad_plasmid.tsv"
    )
    phage_inpath = f"data/species/sp-{species}/midasdb_v15.gene75_x_genomad_virus.tsv"
    # The cluster table is read with explicit column names (gene, cluster id)
    # and tagged with its species.
    clust.append(
        pd.read_table(clust_inpath, names=["centroid_75", "clust"]).assign(
            species=species
        )
    )
    kegg_module.append(pd.read_table(kegg_module_inpath))
    cog_category.append(pd.read_table(cog_category_inpath))
    plasmid.append(pd.read_table(plasmid_inpath))
    phage.append(pd.read_table(phage_inpath))

# Cluster ids are only meaningful within a species, so build a label
# "<clust>_<species>" that is distinct across species.
clust = pd.concat(clust).assign(
    clust_label=lambda x: x.clust.astype(str) + "_" + x.species
)
kegg_module = pd.concat(kegg_module)
cog_category = pd.concat(cog_category)
plasmid = pd.concat(plasmid)
phage = pd.concat(phage)

In [None]:
# How many clusters (with >=2 genes)?
# Displays the gene count of every such cluster (rows with clust < 0 are
# excluded); the length of the displayed Series is the number of clusters.
clust[lambda x: (x.clust >= 0)].clust_label.value_counts()[lambda x: x > 1]

In [None]:
# Median number of clusters per species.
# Only clusters with more than one gene (and clust >= 0) are counted: per-species
# cluster sizes are filtered to > 1, then the remaining clusters are tallied per
# species and the median of those tallies is taken.
clust[lambda x: (x.clust >= 0)].groupby("species").clust_label.value_counts()[
    lambda x: x > 1
].reset_index().species.value_counts().median()

In [None]:
# Heatmap of cluster-size distributions: one row per species, one column per
# power-of-two size bin.
# Histogram edges 2, 4, ..., 1024; clusters of size 1 fall below the first edge.
bins = np.array([2**i for i in range(1, 11)])

# Per species, the number of clusters whose gene count falls in each bin
# (columns are labelled by the lower bin edge).
d = (
    clust[lambda x: x.clust >= 0]
    .groupby("species")
    .apply(
        lambda d: pd.Series(
            np.histogram(d.clust_label.value_counts(), bins=bins)[0], index=bins[:-1]
        )
    )
)
# Order by total number of genes in clusters with size ≥2
species_order = (
    (
        clust[lambda x: x.clust >= 0][["species", "clust"]].value_counts()[
            lambda x: x > 1
        ]
    )
    .groupby("species")
    .sum()
    .sort_values(ascending=False)
    .index
)

# Colour strip per species: species -> phylum -> palette colour.
row_colors = (
    d.index.to_series().map(species_taxonomy.p__).map(phylum_palette)#.rename("")
)

# The same norm and colormap are shared by the heatmap and the manual colorbar.
norm = mpl.colors.SymLogNorm(1, vmin=0, vmax=2_000)
cmap = sns.color_palette("rocket", as_cmap=True)

cg = sns.clustermap(
    d.loc[species_order].drop(columns=[1], errors="ignore"),
    norm=norm,
    cmap=cmap,
    xticklabels=1,
    yticklabels=0,
    col_cluster=False,
    row_cluster=False,
    row_colors=row_colors,
    figsize=(4.5, 3.5),
    cbar_pos=None,
)
cg.ax_heatmap.set_ylabel("")
# cbar_pos=None suppresses seaborn's colorbar; draw one next to the heatmap instead.
cg.figure.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=cg.ax_heatmap)
# Nudge the x ticks by 0.45 and relabel them with the bin edges.
cg.ax_heatmap.set_xticks(
    ticks=np.array(cg.ax_heatmap.get_xticks()) + 0.45, labels=d.columns, rotation=0
)

# fig, ax = plt.subplots(figsize=(0.15, 1.5), facecolor='none')
# fig.colorbar(
# # lib.plot.hide_axes_and_spines()

In [None]:
# Transposed version of the cluster-size heatmap (size bins as rows, species as
# columns), with the colorbars drawn as separate standalone figures.
# Histogram edges 2, 4, ..., 1024; clusters of size 1 fall below the first edge.
bins = np.array([2**i for i in range(1, 11)])

# Per species, the number of clusters whose gene count falls in each bin
# (columns are labelled by the lower bin edge).
d = (
    clust[lambda x: x.clust >= 0]
    .groupby("species")
    .apply(
        lambda d: pd.Series(
            np.histogram(d.clust_label.value_counts(), bins=bins)[0], index=bins[:-1]
        )
    )
)
# Order by total number of genes in clusters with size ≥2
species_order = (
    (
        clust[lambda x: x.clust >= 0][["species", "clust"]].value_counts()[
            lambda x: x > 1
        ]
    )
    .groupby("species")
    .sum()
    .sort_values(ascending=False)
    .index
)

# Colour strip per species (species -> phylum -> palette colour); the empty name
# is passed to seaborn as the strip's label.
row_colors = (
    d.index.to_series().map(species_taxonomy.p__).map(phylum_palette).rename("")
)

# The same norm and colormap are shared by the heatmap and both colorbars.
norm = mpl.colors.SymLogNorm(1, vmin=0, vmax=2_000)
cmap = sns.color_palette("rocket", as_cmap=True)

# The data are transposed here, so the per-species colours go on the columns.
cg = sns.clustermap(
    d.loc[species_order].drop(columns=[1], errors="ignore").T,
    norm=norm,
    cmap=cmap,
    xticklabels=0,
    yticklabels=1,
    col_cluster=False,
    row_cluster=False,
    col_colors=row_colors,
    figsize=(6, 3),
    cbar_pos=None,
)
# Nudge the y ticks by 0.45 and relabel them with the bin edges.
cg.ax_heatmap.set_yticks(
    ticks=np.array(cg.ax_heatmap.get_yticks()) - 0.45, labels=d.columns, rotation=0
)
# lib.plot.rotate_yticklabels(ax=cg.ax_heatmap, rotation=0, va='center')

# Standalone vertical colorbar.
fig, ax = plt.subplots(figsize=(0.15, 1.5), facecolor="none")
fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax)
# lib.plot.hide_axes_and_spines()

# Standalone horizontal colorbar.
fig, ax = plt.subplots(figsize=(1.5, 0.15), facecolor="none")
fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal"
)
# lib.plot.hide_axes_and_spines()

In [None]:
# Colorbar-only figure using the heatmaps' norm and colormap: the empty scatter
# exists only to give plt.colorbar() a mappable.
fig = plt.figure(figsize=(2, 5), facecolor="none")
plt.scatter(
    [],
    [],
    c=[],
    norm=mpl.colors.SymLogNorm(1, vmin=0, vmax=2_000),
    cmap=sns.color_palette("rocket", as_cmap=True),
)
plt.colorbar()
lib.plot.hide_axes_and_spines()

In [None]:
# Labels of clusters (clust >= 0) that contain more than one gene.
# NOTE(review): the annotation-enrichment cells below filter on clust > 0 rather
# than clust >= 0 — confirm which threshold is intended.
non_singleton_clusters = idxwhere(
    (clust[lambda x: x.clust >= 0].clust_label.value_counts() > 1)
)
len(non_singleton_clusters)

### KEGG Modules

In [None]:
# Permutation test: do genes in the same cluster share a KEGG module more often
# than expected when cluster labels are shuffled within each species?
_annot = kegg_module.rename(columns={"kegg_module": "annot"})

# NOTE(review): this keeps clust > 0, whereas the cluster-size cells above keep
# clust >= 0 — confirm whether cluster 0 is meant to be excluded here.
all_counted_genes_clust_info = clust[lambda x: x.clust > 0]
# Observed statistic: for each cluster, the largest number of member genes that
# share a single annotation. value_counts() drops rows with a missing annotation,
# so clusters without any annotated gene do not appear.
obs_clust_annot_count = (
    pd.merge(all_counted_genes_clust_info, _annot, on="centroid_75", how="left")[
        ["clust_label", "annot"]
    ]
    .value_counts()
    .groupby("clust_label")
    .max()
)

# Null distribution: shuffle cluster labels among genes within each species and
# recompute the same per-cluster statistic. Seeded for reproducibility.
np.random.seed(0)
n_perm = 100
perm_clust_annot_count_list = []
for i in tqdm(range(n_perm)):
    perm_clust_annot_count = (
        pd.merge(
            all_counted_genes_clust_info.groupby("species").apply(
                lambda d: d.assign(
                    clust_label=lambda x: x.clust_label.sample(frac=1).values
                )
            ),
            _annot,
            on="centroid_75",
            how="left",
        )[["clust_label", "annot"]]
        .value_counts()
        .groupby("clust_label")
        .max()
    )
    perm_clust_annot_count_list.append(perm_clust_annot_count)


# Unit-width integer bins [k, k + 1) for k = 1 .. max count.
# np.histogram closes its final bin on both sides, so edges that stop at the
# maximum would merge the counts max - 1 and max into one bin. The edges therefore
# extend one past the largest count. The largest count is taken over the observed
# and the permuted data so that no permuted value falls outside the range and is
# silently dropped.
_max_count = int(
    max(
        [obs_clust_annot_count.max()]
        + [perm.max() for perm in perm_clust_annot_count_list]
    )
)
bins = np.arange(1, _max_count + 2)

# One row per permutation, one column per count value.
perm_histogram2d = pd.DataFrame(
    np.stack(
        [
            np.histogram(perm_clust_annot_count, bins=bins)[0]
            for perm_clust_annot_count in perm_clust_annot_count_list
        ]
    ),
    columns=bins[:-1],
)
obs_histogram = pd.Series(
    np.histogram(obs_clust_annot_count, bins=bins)[0], index=bins[:-1]
)
# Quick look at the permutation histograms (permutations x count values).
sns.clustermap(
    perm_histogram2d,
    norm=mpl.colors.SymLogNorm(linthresh=1),
    row_cluster=False,
    col_cluster=False,
)

# Keep KEGG-specific copies, because the next analysis reuses the generic names.
kegg_module_bins = bins
# NOTE(review): this stores only the counts from the final permutation (the last
# value bound by the loop above), not the full list — confirm that is intended.
kegg_module_perm_clust_annot_count = perm_clust_annot_count
kegg_module_obs_clust_annot_count = obs_clust_annot_count
kegg_module_perm_histogram2d = perm_histogram2d
kegg_module_obs_histogram = obs_histogram

In [None]:
# Observed vs. null distribution of the per-cluster maximum annotation count
# (KEGG modules).
fig, ax = plt.subplots(figsize=(4, 2))

# _perm holds already-binned counts (permutations x bins); _obs holds the raw
# per-cluster values, which ax.hist bins below using the same edges.
_perm = kegg_module_perm_histogram2d
_obs = kegg_module_obs_clust_annot_count
_bins = kegg_module_bins

# # Plot the permutation histograms
# for perm_clust_annot_count in perm_clust_annot_count_list:
#     ax.hist(perm_clust_annot_count, bins=bins, align='left', color='k', alpha=0.5 / n_perm)
# # Plot the expected values.
# ax.stairs(perm_histogram2d.mean(0), bins - 0.5, edgecolor='k', facecolor='none', lw=1)

# Plot the mean
# The mean null histogram is drawn twice: a translucent fill, then a solid outline.
# Edges are shifted by 0.5 so the steps are centred on the integer counts,
# matching align="left" in the observed histogram below.
ax.stairs(
    _perm.mean(0),
    _bins - 0.5,
    edgecolor="none",
    facecolor="k",
    lw=1,
    fill=True,
    alpha=0.2,
)
ax.stairs(
    _perm.mean(0),
    _bins - 0.5,
    edgecolor="k",
    facecolor="none",
    lw=1,
    fill=True,
    alpha=1.0,
)

# ax.hist(perm_histogram2d.mean(0), bins=bins, align='left', color='k', alpha=0.5)
# ax.hist(perm_histogram2d.mean(0), bins=bins, align='left', color='k', histtype='step')

ax.hist(_obs, bins=_bins, align="left", histtype="step", color="tab:orange")
# ax.hist(obs_clust_annot_count, bins=bins, align='left', histtype="stepfilled", color='tab:orange', alpha=0.2)
ax.set_yscale("symlog", linthresh=1, linscale=0.1)

# Empty scatters provide square legend handles for the two distributions.
ax.scatter(
    [],
    [],
    edgecolor="tab:orange",
    facecolor="none",
    lw=1,
    label="observed",
    marker="s",
    s=40,
)
ax.scatter(
    [], [], edgecolor="k", facecolor="silver", label="null", marker="s", s=40, lw=1
)
ax.legend(loc="upper right")

# Ticks at every integer from 1 to 7, then every 5 from 10 up to the observed maximum.
ax.set_xticks(
    np.concatenate([np.arange(1, 8, step=1), np.arange(10, _obs.max() + 1, step=5)])
)
ax.set_xlabel("Max Annotations")
ax.set_ylabel("Clusters (count)")
ax.set_ylim(0, 1e5)
ax.set_yticks(np.logspace(0, 5, num=6))

In [None]:
# Compare observed and null tail counts for KEGG modules.
_perm = kegg_module_perm_histogram2d
_obs = kegg_module_obs_histogram

# Cumulative sums taken from the largest bin downwards give, at label k, the
# number of clusters whose count is >= k; here k = 3 (observed value, and one
# value per permutation for the null).
obs_num_clust_ge3 = _obs.iloc[::-1].cumsum().loc[3]
null_num_clust_ge3 = _perm.iloc[:, ::-1].cumsum(1)[3]

# Null distribution of that tail count, with the observed value as a vertical line.
plt.hist(null_num_clust_ge3)
plt.axvline(obs_num_clust_ge3, color="k")

# Table of tail counts for every threshold: observed, mean over permutations,
# and their ratio.
d = pd.DataFrame(
    dict(
        obs=_obs.iloc[::-1].cumsum().iloc[::-1],
        perm=_perm.iloc[:, ::-1].cumsum(1).iloc[:, ::-1].mean(),
    )
).assign(ratio=lambda x: x.obs / x.perm)
d

### Plasmid / Phage

In [None]:
# Permutation test: do genes in the same cluster share a plasmid / phage
# annotation more often than expected when cluster labels are shuffled within
# each species?
# Long table with one row per (gene, annotation) pair, where annotation is
# "phage" or "plasmid". Each input table contributes a True flag per gene;
# stack() leaves one entry per flag that is present, and the flag column itself
# (named 0 after reset_index) is then dropped.
_annot = (
    pd.DataFrame(
        dict(
            phage=phage.assign(phage=True)[["centroid_75", "phage"]]
            .drop_duplicates()
            .set_index("centroid_75")
            .phage,
            plasmid=plasmid.assign(plasmid=True)[["centroid_75", "plasmid"]]
            .drop_duplicates()
            .set_index("centroid_75")
            .plasmid,
        )
    )
    .stack()
    .sort_values()
    .rename_axis(["centroid_75", "annot"])
    .reset_index()
    .drop(columns=[0])
)

# NOTE(review): this keeps clust > 0, whereas the cluster-size cells above keep
# clust >= 0 — confirm whether cluster 0 is meant to be excluded here.
all_counted_genes_clust_info = clust[lambda x: x.clust > 0]
# Observed statistic: for each cluster, the largest number of member genes that
# share a single annotation. value_counts() drops rows with a missing annotation,
# so clusters without any annotated gene do not appear.
obs_clust_annot_count = (
    pd.merge(all_counted_genes_clust_info, _annot, on="centroid_75", how="left")[
        ["clust_label", "annot"]
    ]
    .value_counts()
    .groupby("clust_label")
    .max()
)

# Null distribution: shuffle cluster labels among genes within each species and
# recompute the same per-cluster statistic. Seeded for reproducibility.
np.random.seed(0)
n_perm = 100
perm_clust_annot_count_list = []
for i in tqdm(range(n_perm)):
    perm_clust_annot_count = (
        pd.merge(
            all_counted_genes_clust_info.groupby("species").apply(
                lambda d: d.assign(
                    clust_label=lambda x: x.clust_label.sample(frac=1).values
                )
            ),
            _annot,
            on="centroid_75",
            how="left",
        )[["clust_label", "annot"]]
        .value_counts()
        .groupby("clust_label")
        .max()
    )
    perm_clust_annot_count_list.append(perm_clust_annot_count)


# Unit-width integer bins [k, k + 1) for k = 1 .. max count.
# np.histogram closes its final bin on both sides, so edges that stop at the
# maximum would merge the counts max - 1 and max into one bin. The edges therefore
# extend one past the largest count. The largest count is taken over the observed
# and the permuted data so that no permuted value falls outside the range and is
# silently dropped.
_max_count = int(
    max(
        [obs_clust_annot_count.max()]
        + [perm.max() for perm in perm_clust_annot_count_list]
    )
)
bins = np.arange(1, _max_count + 2)

# One row per permutation, one column per count value.
perm_histogram2d = pd.DataFrame(
    np.stack(
        [
            np.histogram(perm_clust_annot_count, bins=bins)[0]
            for perm_clust_annot_count in perm_clust_annot_count_list
        ]
    ),
    columns=bins[:-1],
)
obs_histogram = pd.Series(
    np.histogram(obs_clust_annot_count, bins=bins)[0], index=bins[:-1]
)
# Quick look at the permutation histograms (permutations x count values).
sns.clustermap(
    perm_histogram2d,
    norm=mpl.colors.SymLogNorm(linthresh=1),
    row_cluster=False,
    col_cluster=False,
)

# Keep plasmid/phage-specific copies, because the generic names are reused
# between analyses.
plasmid_phage_bins = bins
# NOTE(review): this stores only the counts from the final permutation (the last
# value bound by the loop above), not the full list — confirm that is intended.
plasmid_phage_perm_clust_annot_count = perm_clust_annot_count
plasmid_phage_obs_clust_annot_count = obs_clust_annot_count
plasmid_phage_perm_histogram2d = perm_histogram2d
plasmid_phage_obs_histogram = obs_histogram

In [None]:
# Observed vs. null distribution of the per-cluster maximum annotation count
# (plasmid / phage).
fig, ax = plt.subplots(figsize=(4, 2))

# _perm holds already-binned counts (permutations x bins); _obs holds the raw
# per-cluster values, which ax.hist bins below using the same edges.
_perm = plasmid_phage_perm_histogram2d
_obs = plasmid_phage_obs_clust_annot_count
_bins = plasmid_phage_bins

# # Plot the permutation histograms
# for perm_clust_annot_count in perm_clust_annot_count_list:
#     ax.hist(perm_clust_annot_count, bins=bins, align='left', color='k', alpha=0.5 / n_perm)
# # Plot the expected values.
# ax.stairs(perm_histogram2d.mean(0), bins - 0.5, edgecolor='k', facecolor='none', lw=1)

# Plot the mean
# The mean null histogram is drawn twice: a translucent fill, then a solid outline.
# Edges are shifted by 0.5 so the steps are centred on the integer counts,
# matching align="left" in the observed histogram below.
ax.stairs(
    _perm.mean(0),
    _bins - 0.5,
    edgecolor="none",
    facecolor="k",
    lw=1,
    fill=True,
    alpha=0.2,
)
ax.stairs(
    _perm.mean(0),
    _bins - 0.5,
    edgecolor="k",
    facecolor="none",
    lw=1,
    fill=True,
    alpha=1.0,
)

# ax.hist(perm_histogram2d.mean(0), bins=bins, align='left', color='k', alpha=0.5)
# ax.hist(perm_histogram2d.mean(0), bins=bins, align='left', color='k', histtype='step')

ax.hist(_obs, bins=_bins, align="left", histtype="step", color="tab:orange")
# ax.hist(obs_clust_annot_count, bins=bins, align='left', histtype="stepfilled", color='tab:orange', alpha=0.2)
ax.set_yscale("symlog", linthresh=1, linscale=0.1)

# Empty scatters provide square legend handles for the two distributions.
ax.scatter(
    [],
    [],
    edgecolor="tab:orange",
    facecolor="none",
    lw=1,
    label="observed",
    marker="s",
    s=40,
)
ax.scatter(
    [], [], edgecolor="k", facecolor="silver", label="null", marker="s", s=40, lw=1
)
ax.legend(loc="upper right")

# Ticks at the odd integers from 1 to 7, then every 5 from 10 up to the observed maximum.
ax.set_xticks(
    np.concatenate([np.arange(1, 8, step=2), np.arange(10, _obs.max() + 1, step=5)])
)
ax.set_xlabel("Max Annotations")
ax.set_ylabel("Clusters (count)")
ax.set_ylim(0, 1e5)
ax.set_yticks(np.logspace(0, 5, num=6))

In [None]:
# Compare observed and null tail counts for plasmid / phage annotations.
_perm = plasmid_phage_perm_histogram2d
_obs = plasmid_phage_obs_histogram

# Cumulative sums taken from the largest bin downwards give, at label k, the
# number of clusters whose count is >= k; here k = 3 (observed value, and one
# value per permutation for the null).
obs_num_clust_ge3 = _obs.iloc[::-1].cumsum().loc[3]
null_num_clust_ge3 = _perm.iloc[:, ::-1].cumsum(1)[3]

# Null distribution of that tail count, with the observed value as a vertical line.
plt.hist(null_num_clust_ge3)
plt.axvline(obs_num_clust_ge3, color="k")

# Table of tail counts for every threshold: observed, mean over permutations,
# and their ratio.
d = pd.DataFrame(
    dict(
        obs=_obs.iloc[::-1].cumsum().iloc[::-1],
        perm=_perm.iloc[:, ::-1].cumsum(1).iloc[:, ::-1].mean(),
    )
).assign(ratio=lambda x: x.obs / x.perm)
d