## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so the relative meta/, ref/ and data/ paths below resolve.
_os.chdir(_os.environ["PROJECT_ROOT"])
_os.path.realpath(_os.curdir)  # echo the resolved working directory

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 100

## Metadata

In [None]:
# Species IDs (as strings) that belong to the "hmp2" species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda table: table["species_group_id"] == "hmp2", "species_id"]
    .astype(str)
    .unique()
)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-separated lineage string into a Series indexed by rank prefix.

    The string must contain exactly seven fields (domain through species);
    otherwise pandas raises ValueError on the index length mismatch.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Lineage table for species-representative genomes only. species_id is "1" followed by
# the third '-'-separated field of MGnify_accession; each Lineage string is expanded
# into one column per rank.
species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)
    .loc[lambda genomes: genomes["Genome"] == genomes["Species_rep"]]
    .assign(species_id=lambda reps: "1" + reps["MGnify_accession"].str.split("-").str[2])
    .set_index("species_id")["Lineage"]
    .apply(parse_taxonomy_string)
)
species_taxonomy

In [None]:
# Display order of phyla; palette colors are assigned following this order.
# Not listed (absent from the species analysed here, or unclear placement relative
# to Firmicutes_A/C): p__Firmicutes_B, p__Firmicutes_G, p__Firmicutes_I, p__Cyanobacteria.
phylum_order = [
    "p__Euryarchaeota",
    "p__Thermoplasmatota",
    "p__Firmicutes",
    "p__Firmicutes_A",
    "p__Firmicutes_C",
    "p__Actinobacteriota",
    "p__Synergistota",
    "p__Fusobacteriota",
    "p__Campylobacterota",
    "p__Proteobacteria",
    "p__Desulfobacterota_A",
    "p__Bacteroidota",
    "p__Verrucomicrobiota",
]

phylum_palette = lib.plot.construct_ordered_palette(
    phylum_order,
    cm="rainbow",
    desaturate_levels=[1.0, 0.5],
)

# Preview the palette: print each color and draw a legend-only figure.
for phylum in phylum_order:
    color = phylum_palette[phylum]
    print(phylum, color)
    plt.scatter([], [], color=color, label=phylum)
plt.legend(ncols=4)
lib.plot.hide_axes_and_spines()

# TODO: colors are not checked for uniqueness; candidate check:
# assert len(set(phylum_palette.values())) == len(phylum_palette)

## Analysis

In [None]:
# Focal species for the single-species exploration below.
species = '103166'

# Per-species annotation tables under data/species/.
cog_category_inpath = f'data/species/sp-{species}/midasdb_v15.emapper.gene75_x_cog_category.tsv'
eggnog_inpath = f'data/species/sp-{species}/midasdb_v15.emapper.gene75_x_eggnog.tsv'
ko_inpath = f'data/species/sp-{species}/midasdb_v15.emapper.gene75_x_ko.tsv'
kegg_module_inpath = f'data/species/sp-{species}/midasdb_v15.emapper.gene75_x_kegg_module.tsv'
amr_inpath = f'data/species/sp-{species}/midasdb_v15.gene75_x_amr.tsv'
plasmid_inpath = f'data/species/sp-{species}/midasdb_v15.gene75_x_genomad_plasmid.tsv'
phage_inpath = f'data/species/sp-{species}/midasdb_v15.gene75_x_genomad_virus.tsv'

# Reference pangenome files (MIDASDB UHGG v1.5).
emapper_inpath = f'ref/midasdb_uhgg_v15/pangenomes/{species}/eggnog.tsv'
gene_family_inpath = f'ref/midasdb_uhgg_v15/pangenomes/{species}/gene_info.txt'

# HMP2 group outputs: strain-gene table, gene clusters, per-gene statistics, genotype distances.
strain_gene_inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.tsv"
clust_inpath = f'data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.gene_clust-t10.tsv'
morans_i_inpath = f'data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.morans_i.tsv'
prevalence_inpath = f'data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.prevalence.tsv'
pdist_inpath = f'data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.spgc_ss-all.geno_uhgg-v15_pdist-mask10-pseudo10.pkl'

In [None]:
# Pangenome gene table indexed by gene_id; its centroid_75 column is used below.
gene_family = pd.read_csv(gene_family_inpath, sep="\t", index_col="gene_id")

In [None]:
# Long-format centroid_75 -> kegg_module annotations for the focal species.
kegg_module = pd.read_csv(kegg_module_inpath, sep="\t")

In [None]:
# eggNOG-mapper output for every centroid_99 in the pangenome.
emapper_all = (
    pd.read_table(emapper_inpath, index_col='#query')
    .rename_axis('centroid_99')
)
# Keep only rows for centroid_75 representatives that have a seed ortholog hit.
description = (
    emapper_all
    .reindex(gene_family.centroid_75.unique())
    .dropna(subset=['seed_ortholog'])
    .rename_axis('centroid_75')
)

In [None]:
# geNomad phage/plasmid hits and AMR hits (long format, grouped by centroid_75 later).
phage = pd.read_table(phage_inpath)
plasmid = pd.read_table(plasmid_inpath)
amr = pd.read_table(amr_inpath)


def _read_gene_values(path, column):
    """Read a header-less two-column table as a Series of `column` values indexed by centroid_75."""
    return pd.read_table(path, names=['centroid_75', column], index_col='centroid_75')[column]


clust = _read_gene_values(clust_inpath, 'clust')
clust_size = clust.value_counts()  # number of genes carrying each cluster label
morans_i = _read_gene_values(morans_i_inpath, 'morans_i').dropna()
prevalence = _read_gene_values(prevalence_inpath, 'prevalence')

In [None]:
# Long-format eggNOG and COG-category annotations (possibly several rows per centroid_75).
eggnog, cog_category = (pd.read_table(path) for path in (eggnog_inpath, cog_category_inpath))

In [None]:
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform

strain_gene = pd.read_table(strain_gene_inpath, index_col='gene_id')
_strains = strain_gene.columns
# Genotype distance matrix restricted to, and ordered like, strain_gene's columns.
geno_dmat = load_dmat_as_pickle(pdist_inpath).loc[_strains, _strains]
# Square matrix -> condensed vector -> hierarchical clustering (scipy's default method)
# with optimal leaf ordering, for use as a column dendrogram.
geno_linkage = linkage(squareform(geno_dmat), optimal_ordering=True)

In [None]:
def _collapse_per_gene(table, column, sep=';'):
    """Concatenate every `column` value of `table` for each centroid_75, joined by `sep`."""
    return table.groupby('centroid_75')[column].apply(sep.join)


# Per-gene summary for the focal species: one row per centroid_75 in gene_family, with
# eggNOG descriptions, spatial statistics, cluster membership, and collapsed annotations.
d = (
    description[['Preferred_name', 'Description']]
    .reindex(gene_family.centroid_75.unique())
    .assign(morans_i=morans_i)
    .assign(
        cog_category=_collapse_per_gene(cog_category, 'cog_category', sep=''),
        eggnog=_collapse_per_gene(eggnog, 'eggnog'),
        kegg_module=_collapse_per_gene(kegg_module, 'kegg_module'),
        clust=clust,
        csize=lambda genes: genes.clust.map(clust_size),
        plasmid=_collapse_per_gene(plasmid, 'annotation_accessions'),
        phage=_collapse_per_gene(phage, 'annotation_accessions'),
        amr=_collapse_per_gene(amr, 'accession_no'),
        prevalence=prevalence,
    )
)

In [None]:
# Single-species permutation test (KEGG modules): do genes in the same strain-gene
# cluster share KEGG modules more often than when cluster labels are shuffled?

# One row per (gene, KEGG module) annotation, restricted to genes whose cluster label
# is >= 0 (missing or negative labels fail the comparison and are dropped).
module_clust = (
    kegg_module.join(clust, on='centroid_75')
    [lambda df: df.clust >= 0]
    [['centroid_75', 'clust', 'kegg_module']]
)
x = module_clust[['clust', 'kegg_module']]
non_singleton_clusters = idxwhere(clust_size > 1)

# One cluster label per annotated gene. The null shuffles these gene-level labels so that
# all of a gene's module annotations move together; shuffling annotation rows instead
# would scatter a multi-module gene across several clusters and distort the null.
gene_clust = module_clust.drop_duplicates('centroid_75').set_index('centroid_75').clust

fig, ax = plt.subplots()

xmax = 15
bins = np.arange(xmax) - 0.5  # unit-width bins centred on the integers 0..13

# Observed number of genes per (cluster, module) pair, non-singleton clusters only.
obs_counts = x[lambda y: y.clust.isin(non_singleton_clusters)].value_counts()
ax.hist(obs_counts, bins=bins, alpha=1.0, histtype='step', lw=2, align='mid', color='tab:orange', density=False, label='__nolegend__')

np.random.seed(0)  # pandas .sample() without random_state draws from NumPy's global RNG
n_reps = 100
all_perm_counts = []
for _ in range(n_reps):
    shuffled_clust = pd.Series(gene_clust.sample(frac=1).to_numpy(), index=gene_clust.index)
    x_perm = x.assign(clust=module_clust.centroid_75.map(shuffled_clust).to_numpy())
    x_counts = x_perm[lambda y: y.clust.isin(non_singleton_clusters)].value_counts()
    all_perm_counts.append(x_counts)
    # Faint overlaid null histograms; darker where many replicates agree.
    ax.hist(x_counts, bins=bins, alpha=0.8 / n_reps, align='mid', color='k', density=False, label='__nolegend__')

# Proxy artists for the legend (the histograms themselves carry no labels).
ax.scatter([], [], edgecolor='tab:orange', facecolor='none', lw=2, label='observed', marker='s', s=200)
ax.scatter([], [], c='black', alpha=0.2, label='null', marker='s', s=200, lw=0)
ax.legend(loc='upper right')

ax.set_xticks([1, 3, 5, 7, 9, 11])
ax.set_xlabel('Genes per Module')
ax.set_ylabel('Cluster-Modules (count)')
ax.set_yscale('symlog', linthresh=1, linscale=0.1)

# Top observed (cluster, module) pairs with their cluster's total size. clust_size is
# renamed so its column cannot collide with the 'clust' key column (pandas < 2 names a
# value_counts() result after its source column).
obs_counts.to_frame('tally').reset_index().join(clust_size.rename('clust_size'), on='clust').head(20)

In [None]:
# Permutation summary (KEGG modules). For each threshold k, print: k, the observed number
# of (cluster, module) pairs with at least k genes, and the largest such number in any
# permutation replicate.
for threshold in range(2, 9):
    n_observed = (obs_counts >= threshold).sum()
    n_null_max = max((null_counts >= threshold).sum() for null_counts in all_perm_counts)
    print(threshold, n_observed, n_null_max)

In [None]:
# Single-species permutation test (COG categories): do genes in the same strain-gene
# cluster share COG categories more often than when cluster labels are shuffled?
# Genes annotated 'no_category' are excluded.

# One row per (gene, COG category) annotation, restricted to genes whose cluster label
# is >= 0 (missing or negative labels fail the comparison and are dropped).
category_clust = (
    cog_category[lambda df: df.cog_category != 'no_category']
    .join(clust, on='centroid_75')
    [lambda df: df.clust >= 0]
    [['centroid_75', 'clust', 'cog_category']]
)
x = category_clust[['clust', 'cog_category']]
non_singleton_clusters = idxwhere(clust_size > 1)

# One cluster label per annotated gene. The null shuffles these gene-level labels so that
# all of a gene's categories move together; shuffling annotation rows instead would
# scatter a multi-category gene across several clusters and distort the null.
gene_clust = category_clust.drop_duplicates('centroid_75').set_index('centroid_75').clust

fig, ax = plt.subplots()

xmax = 15
bins = np.arange(xmax) - 0.5  # unit-width bins centred on the integers 0..13

# Observed number of genes per (cluster, category) pair, non-singleton clusters only.
obs_counts = x[lambda y: y.clust.isin(non_singleton_clusters)].value_counts()
ax.hist(obs_counts, bins=bins, alpha=1.0, histtype='step', lw=2, align='mid', color='tab:orange', density=False, label='__nolegend__')

np.random.seed(0)  # pandas .sample() without random_state draws from NumPy's global RNG
n_reps = 100
all_perm_counts = []
for _ in range(n_reps):
    shuffled_clust = pd.Series(gene_clust.sample(frac=1).to_numpy(), index=gene_clust.index)
    x_perm = x.assign(clust=category_clust.centroid_75.map(shuffled_clust).to_numpy())
    x_counts = x_perm[lambda y: y.clust.isin(non_singleton_clusters)].value_counts()
    all_perm_counts.append(x_counts)
    # Faint overlaid null histograms; darker where many replicates agree.
    ax.hist(x_counts, bins=bins, alpha=0.8 / n_reps, align='mid', color='k', density=False, label='__nolegend__')

# Proxy artists for the legend (the histograms themselves carry no labels).
ax.scatter([], [], edgecolor='tab:orange', facecolor='none', lw=2, label='observed', marker='s', s=200)
ax.scatter([], [], c='black', alpha=0.2, label='null', marker='s', s=200, lw=0)
ax.legend(loc='upper right')

ax.set_xticks([1, 3, 5, 7, 9, 11])
ax.set_xlabel('Genes per Category')
ax.set_ylabel('Cluster-Categories (count)')
ax.set_yscale('symlog', linthresh=1, linscale=0.1)

# Top observed (cluster, category) pairs with their cluster's total size. clust_size is
# renamed so its column cannot collide with the 'clust' key column (pandas < 2 names a
# value_counts() result after its source column).
obs_counts.to_frame('tally').reset_index().join(clust_size.rename('clust_size'), on='clust').head(20)

In [None]:
# Permutation summary (COG categories). For each threshold k, print: k, the observed
# number of (cluster, category) pairs with at least k genes, and the largest such number
# in any permutation replicate.
for threshold in range(2, 9):
    n_observed = (obs_counts >= threshold).sum()
    n_null_max = max((null_counts >= threshold).sum() for null_counts in all_perm_counts)
    print(threshold, n_observed, n_null_max)

In [None]:
# Per-gene annotations for the genes assigned to cluster 2441.
d.loc[d["clust"] == 2441].head(50)

In [None]:
# Heatmap of strain_gene rows for the genes in cluster 2441; columns (strains) use the
# genotype-distance linkage computed above.
# Cluster labels are aligned to strain_gene's gene index first. Genes absent from `clust`
# count as "not in the cluster" instead of making the boolean mask unalignable.
in_cluster_2441 = clust.reindex(strain_gene.index).eq(2441)
sns.clustermap(strain_gene.loc[in_cluster_2441.to_numpy()], col_linkage=geno_linkage)

## Higher rate of shared COG categories and KEGG modules (and plasmid/phage annotations?) within clusters

In [None]:
# Pool the analysis across species:
# 1. Gather strain-gene cluster assignments for every species in species_list4.
# 2. Gather KEGG module, COG category, plasmid, and phage annotations for those genes.
# 3. Repeat the cluster/annotation permutation test on the pooled data,
#    shuffling cluster labels within each species.

In [None]:
# Species pooled for the cross-species analysis.
species_list4 = [
    "102492", "103694", "100022", "102545", "102272", "102478",
    "101300", "101346", "102438", "101302", "100254", "100217",
    "102549", "100271", "101378", "100196", "103899", "102528",
    "100074", "102321", "101345", "104158", "100251", "100562",
    "100099", "100209", "100078", "101830", "103681", "100003",
    "103702", "101337", "100044", "103937", "102517", "103166",
    "103686", "102506", "102327", "101338", "102040", "102274",
    "100205", "102292", "100208", "100144", "101292", "100038",
    "103439", "100233", "102454", "100154", "101374",
]

# Per-species tables are collected in lists and concatenated once after the loop.
clust_parts, kegg_module_parts, cog_category_parts, plasmid_parts, phage_parts = [], [], [], [], []

for species in tqdm(species_list4):
    clust_inpath = f'data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.uhgg-strain_gene.gene_clust-t10.tsv'
    kegg_module_inpath = f'data/species/sp-{species}/midasdb_v15.emapper.gene75_x_kegg_module.tsv'
    cog_category_inpath = f'data/species/sp-{species}/midasdb_v15.emapper.gene75_x_cog_category.tsv'
    plasmid_inpath = f'data/species/sp-{species}/midasdb_v15.gene75_x_genomad_plasmid.tsv'
    phage_inpath = f'data/species/sp-{species}/midasdb_v15.gene75_x_genomad_virus.tsv'

    # Header-less centroid_75 -> cluster table, tagged with its species.
    clust_parts.append(
        pd.read_table(clust_inpath, names=['centroid_75', 'clust']).assign(species=species)
    )
    kegg_module_parts.append(pd.read_table(kegg_module_inpath))
    cog_category_parts.append(pd.read_table(cog_category_inpath))
    plasmid_parts.append(pd.read_table(plasmid_inpath))
    phage_parts.append(pd.read_table(phage_inpath))

# Cluster labels are only unique within a species, so clust_label combines both.
clust = pd.concat(clust_parts).assign(
    clust_label=lambda genes: genes.clust.astype(str) + '_' + genes.species
)
kegg_module = pd.concat(kegg_module_parts)
cog_category = pd.concat(cog_category_parts)
plasmid = pd.concat(plasmid_parts)
phage = pd.concat(phage_parts)

In [None]:
# Per-species distribution of cluster sizes (genes per cluster) on power-of-two bins
# [2, 4), [4, 8), ..., [512, 1024]. Single-gene clusters fall below the first edge
# and clusters above 1024 genes beyond the last, so neither is counted.
bins = 2 ** np.arange(1, 11)

clustered_genes = clust[lambda genes: genes.clust >= 0]


def _cluster_size_histogram(genes):
    """Histogram of one species' cluster sizes, indexed by each bin's lower edge."""
    counts, _ = np.histogram(genes.clust_label.value_counts(), bins=bins)
    return pd.Series(counts, index=bins[:-1])


d = clustered_genes.groupby("species").apply(_cluster_size_histogram)

# Order species by the total number of genes in clusters with at least two genes.
species_order = (
    clustered_genes[["species", "clust"]]
    .value_counts()
    .loc[lambda n_genes: n_genes > 1]
    .groupby("species")
    .sum()
    .sort_values(ascending=False)
    .index
)

# Row annotation: each species' phylum color.
row_colors = (
    d.index.to_series()
    .map(species_taxonomy.p__)
    .map(phylum_palette)
    .rename("phylum")
)

cg = sns.clustermap(
    d.loc[species_order],
    norm=mpl.colors.SymLogNorm(1, vmin=0, vmax=2_000),
    cmap=sns.color_palette("rocket", as_cmap=True),
    xticklabels=1,
    yticklabels=0,
    col_cluster=False,
    row_cluster=False,
    row_colors=row_colors,
    figsize=(6, 5),
    cbar_pos=None,
)

In [None]:
# Stand-alone colorbar matching the cluster-size heatmap (same norm and colormap).
fig = plt.figure(figsize=(2, 5), facecolor='none')
_colorbar_source = plt.scatter(
    [], [], c=[],
    norm=mpl.colors.SymLogNorm(1, vmin=0, vmax=2_000),
    cmap=sns.color_palette("rocket", as_cmap=True),
)
plt.colorbar(_colorbar_source)
lib.plot.hide_axes_and_spines()

In [None]:
# Pooled cluster labels (species-qualified) that contain more than one gene.
_pooled_cluster_sizes = clust[clust.clust >= 0].clust_label.value_counts()
non_singleton_clusters = idxwhere(_pooled_cluster_sizes > 1)
len(non_singleton_clusters)

### KEGG Modules

In [None]:
# Pooled permutation test (KEGG modules). Statistic per cluster: the largest number of
# its genes sharing a single module. Clusters whose genes have no module annotation
# drop out because value_counts() skips NaN annotations from the left merge.
_annot = kegg_module.rename(columns={"kegg_module": "annot"})

# Negative labels mark unclustered genes. Cluster 0 is a real cluster, consistent with
# the `clust >= 0` filters elsewhere in this notebook, so it must be kept.
all_counted_genes_clust_info = clust[lambda x: x.clust >= 0]
obs_clust_annot_count = pd.merge(
    all_counted_genes_clust_info, _annot, on="centroid_75", how="left"
)[["clust_label", "annot"]].value_counts().groupby('clust_label').max()

np.random.seed(0)  # pandas .sample() without random_state draws from NumPy's global RNG
n_perm = 100
perm_clust_annot_count_list = []
for _ in tqdm(range(n_perm)):
    # Shuffle cluster labels among genes within each species. transform keeps rows in
    # place and avoids DataFrameGroupBy.apply operating on the grouping column.
    permuted_genes = all_counted_genes_clust_info.assign(
        clust_label=lambda x: x.groupby("species").clust_label.transform(
            lambda labels: labels.sample(frac=1).to_numpy()
        ).to_numpy()
    )
    perm_clust_annot_count = pd.merge(
        permuted_genes,
        _annot,
        on="centroid_75",
        how="left",
    )[["clust_label", "annot"]].value_counts().groupby('clust_label').max()
    perm_clust_annot_count_list.append(perm_clust_annot_count)

In [None]:
# Bin edges [1, 2, ..., max + 1] give one unit-width bin per integer count from 1 to the
# observed maximum. np.histogram closes only its last bin, so edges ending at the maximum
# (np.arange(max) + 1) merged the counts max-1 and max into a single bin.
# Null values above the observed maximum fall outside the edges and are not counted.
bins = np.arange(1, obs_clust_annot_count.max() + 2)

# Rows: permutation replicates; columns: lower bin edge (= annotation count).
perm_histogram2d = pd.DataFrame(
    np.stack([np.histogram(perm_clust_annot_count, bins=bins)[0] for perm_clust_annot_count in perm_clust_annot_count_list]),
    columns=bins[:-1],
)
obs_histogram = pd.Series(np.histogram(obs_clust_annot_count, bins=bins)[0], index=bins[:-1])

sns.clustermap(perm_histogram2d, norm=mpl.colors.SymLogNorm(linthresh=1), row_cluster=False, col_cluster=False)

In [None]:
fig, ax = plt.subplots(figsize=(6, 3))

# Null: one very faint histogram per permutation replicate; shading darkens where they agree.
for null_counts in perm_clust_annot_count_list:
    ax.hist(null_counts, bins=bins, align='left', color='k', alpha=0.5 / n_perm)

# Mean null histogram as an outline (edges shifted by -0.5 to match align='left').
ax.stairs(perm_histogram2d.mean(axis=0), bins - 0.5, edgecolor='k', facecolor='none', lw=1)

# Observed distribution.
ax.hist(obs_clust_annot_count, bins=bins, align='left', histtype='step', color='tab:orange')
ax.set_yscale('symlog', linthresh=1, linscale=0.1)

# Legend proxies (none of the histograms above carry labels).
square_marker = dict(marker='s', s=200)
ax.scatter([], [], edgecolor='tab:orange', facecolor='none', lw=2, label='observed', **square_marker)
ax.scatter([], [], edgecolor='k', facecolor='silver', lw=1, label='null', **square_marker)
ax.legend(loc='upper right')

# Odd ticks up to 7, then every 5 from 10 to the observed maximum.
tick_positions = np.concatenate([
    np.arange(1, 8, step=2),
    np.arange(10, obs_clust_annot_count.max() + 1, step=5),
])
ax.set_xticks(tick_positions)
ax.set_xlabel('Module Annotations')
ax.set_ylabel('Clusters (count)')
ax.set_ylim(0, 1e5)
ax.set_yticks(np.logspace(0, 5, num=6))

In [None]:
# For each k: mean null number of clusters whose top annotation count is >= k, divided
# by the observed number (reverse cumulative sums give the ">= k" tallies).
null_at_least = perm_histogram2d.iloc[:, ::-1].cumsum(axis=1).iloc[:, ::-1].mean(axis=0)
obs_at_least = obs_histogram.iloc[::-1].cumsum().iloc[::-1]
(null_at_least / obs_at_least).head(20)

### Plasmid / Phage

In [None]:
# Pooled permutation test (phage / plasmid). Long table with one row per
# (centroid_75, annot), where annot is 'phage' or 'plasmid'. Built with concat rather
# than DataFrame.stack(), whose dropping of missing combinations changed in newer pandas.
_annot = pd.concat(
    [
        phage[['centroid_75']].drop_duplicates().assign(annot='phage'),
        plasmid[['centroid_75']].drop_duplicates().assign(annot='plasmid'),
    ],
    ignore_index=True,
)

# Negative labels mark unclustered genes. Cluster 0 is a real cluster, consistent with
# the `clust >= 0` filters elsewhere in this notebook, so it must be kept.
all_counted_genes_clust_info = clust[lambda x: x.clust >= 0]
# Per cluster: the largest number of its genes sharing one annotation (clusters without
# any phage/plasmid gene drop out because value_counts() skips NaN annotations).
obs_clust_annot_count = pd.merge(
    all_counted_genes_clust_info, _annot, on="centroid_75", how="left"
)[["clust_label", "annot"]].value_counts().groupby('clust_label').max()

np.random.seed(0)  # pandas .sample() without random_state draws from NumPy's global RNG
n_perm = 100
perm_clust_annot_count_list = []
for _ in tqdm(range(n_perm)):
    # Shuffle cluster labels among genes within each species. transform keeps rows in
    # place and avoids DataFrameGroupBy.apply operating on the grouping column.
    permuted_genes = all_counted_genes_clust_info.assign(
        clust_label=lambda x: x.groupby("species").clust_label.transform(
            lambda labels: labels.sample(frac=1).to_numpy()
        ).to_numpy()
    )
    perm_clust_annot_count = pd.merge(
        permuted_genes,
        _annot,
        on="centroid_75",
        how="left",
    )[["clust_label", "annot"]].value_counts().groupby('clust_label').max()
    perm_clust_annot_count_list.append(perm_clust_annot_count)

In [None]:
# Bin edges [1, 2, ..., max + 1] give one unit-width bin per integer count from 1 to the
# observed maximum. np.histogram closes only its last bin, so edges ending at the maximum
# (np.arange(max) + 1) merged the counts max-1 and max into a single bin.
# Null values above the observed maximum fall outside the edges and are not counted.
bins = np.arange(1, obs_clust_annot_count.max() + 2)

# Rows: permutation replicates; columns: lower bin edge (= annotation count).
perm_histogram2d = pd.DataFrame(
    np.stack([np.histogram(perm_clust_annot_count, bins=bins)[0] for perm_clust_annot_count in perm_clust_annot_count_list]),
    columns=bins[:-1],
)
obs_histogram = pd.Series(np.histogram(obs_clust_annot_count, bins=bins)[0], index=bins[:-1])

sns.clustermap(perm_histogram2d, norm=mpl.colors.SymLogNorm(linthresh=1), row_cluster=False, col_cluster=False)

In [None]:
fig, ax = plt.subplots(figsize=(6, 3))

# Null: one very faint histogram per permutation replicate; shading darkens where they agree.
for null_counts in perm_clust_annot_count_list:
    ax.hist(null_counts, bins=bins, align='left', color='k', alpha=0.5 / n_perm)

# Mean null histogram as an outline (edges shifted by -0.5 to match align='left').
ax.stairs(perm_histogram2d.mean(axis=0), bins - 0.5, edgecolor='k', facecolor='none', lw=1)

# Observed distribution.
ax.hist(obs_clust_annot_count, bins=bins, align='left', histtype='step', color='tab:orange')
ax.set_yscale('symlog', linthresh=1, linscale=0.1)

# Legend proxies (none of the histograms above carry labels).
square_marker = dict(marker='s', s=200)
ax.scatter([], [], edgecolor='tab:orange', facecolor='none', lw=2, label='observed', **square_marker)
ax.scatter([], [], edgecolor='k', facecolor='silver', lw=1, label='null', **square_marker)
ax.legend(loc='upper right')

# Odd ticks up to 7, then every 5 from 10 to the observed maximum.
tick_positions = np.concatenate([
    np.arange(1, 8, step=2),
    np.arange(10, obs_clust_annot_count.max() + 1, step=5),
])
ax.set_xticks(tick_positions)
ax.set_xlabel('Phage or Plasmid Annotations')
ax.set_ylabel('Clusters (count)')
ax.set_ylim(0, 1e5)

In [None]:
# For each k: mean null number of clusters whose top annotation count is >= k, divided
# by the observed number (reverse cumulative sums give the ">= k" tallies).
null_at_least = perm_histogram2d.iloc[:, ::-1].cumsum(axis=1).iloc[:, ::-1].mean(axis=0)
obs_at_least = obs_histogram.iloc[::-1].cumsum().iloc[::-1]
(null_at_least / obs_at_least).head(20)

### COG Categories

In [None]:
# Pooled permutation test (COG categories, excluding 'no_category'). Statistic per
# cluster: the largest number of its genes sharing one category. Clusters without any
# categorised gene drop out because value_counts() skips NaN annotations from the merge.
_annot = cog_category[lambda x: x.cog_category != 'no_category'].rename(columns={"cog_category": "annot"})

# Negative labels mark unclustered genes. Cluster 0 is a real cluster, consistent with
# the `clust >= 0` filters elsewhere in this notebook, so it must be kept.
all_counted_genes_clust_info = clust[lambda x: x.clust >= 0]
obs_clust_annot_count = pd.merge(
    all_counted_genes_clust_info, _annot, on="centroid_75", how="left"
)[["clust_label", "annot"]].value_counts().groupby('clust_label').max()

np.random.seed(0)  # pandas .sample() without random_state draws from NumPy's global RNG
n_perm = 100
perm_clust_annot_count_list = []
for _ in tqdm(range(n_perm)):
    # Shuffle cluster labels among genes within each species. transform keeps rows in
    # place and avoids DataFrameGroupBy.apply operating on the grouping column.
    permuted_genes = all_counted_genes_clust_info.assign(
        clust_label=lambda x: x.groupby("species").clust_label.transform(
            lambda labels: labels.sample(frac=1).to_numpy()
        ).to_numpy()
    )
    perm_clust_annot_count = pd.merge(
        permuted_genes,
        _annot,
        on="centroid_75",
        how="left",
    )[["clust_label", "annot"]].value_counts().groupby('clust_label').max()
    perm_clust_annot_count_list.append(perm_clust_annot_count)

In [None]:
# Bin edges [1, 2, ..., max + 1] give one unit-width bin per integer count from 1 to the
# observed maximum. np.histogram closes only its last bin, so edges ending at the maximum
# (np.arange(max) + 1) merged the counts max-1 and max into a single bin.
# Null values above the observed maximum fall outside the edges and are not counted.
bins = np.arange(1, obs_clust_annot_count.max() + 2)

# Rows: permutation replicates; columns: lower bin edge (= annotation count).
perm_histogram2d = pd.DataFrame(
    np.stack([np.histogram(perm_clust_annot_count, bins=bins)[0] for perm_clust_annot_count in perm_clust_annot_count_list]),
    columns=bins[:-1],
)
obs_histogram = pd.Series(np.histogram(obs_clust_annot_count, bins=bins)[0], index=bins[:-1])

sns.clustermap(perm_histogram2d, norm=mpl.colors.SymLogNorm(linthresh=1), row_cluster=False, col_cluster=False)

In [None]:
fig, ax = plt.subplots(figsize=(6, 3))

# Null: one very faint histogram per permutation replicate; shading darkens where they agree.
for null_counts in perm_clust_annot_count_list:
    ax.hist(null_counts, bins=bins, align='left', color='k', alpha=0.5 / n_perm)

# Mean null histogram as an outline (edges shifted by -0.5 to match align='left').
ax.stairs(perm_histogram2d.mean(axis=0), bins - 0.5, edgecolor='k', facecolor='none', lw=1)

# Observed distribution.
ax.hist(obs_clust_annot_count, bins=bins, align='left', histtype='step', color='tab:orange')
ax.set_yscale('symlog', linthresh=1, linscale=0.1)

# Legend proxies (none of the histograms above carry labels).
square_marker = dict(marker='s', s=200)
ax.scatter([], [], edgecolor='tab:orange', facecolor='none', lw=2, label='observed', **square_marker)
ax.scatter([], [], edgecolor='k', facecolor='silver', lw=1, label='null', **square_marker)
ax.legend(loc='upper right')

# Odd ticks up to 7, then every 5 from 10 to the observed maximum.
tick_positions = np.concatenate([
    np.arange(1, 8, step=2),
    np.arange(10, obs_clust_annot_count.max() + 1, step=5),
])
ax.set_xticks(tick_positions)
ax.set_xlabel('Num. Genes (most common)')
ax.set_ylabel('Clusters (count)')

In [None]:
# For each k: mean null number of clusters whose top annotation count is >= k, divided
# by the observed number (reverse cumulative sums give the ">= k" tallies).
null_at_least = perm_histogram2d.iloc[:, ::-1].cumsum(axis=1).iloc[:, ::-1].mean(axis=0)
obs_at_least = obs_histogram.iloc[::-1].cumsum().iloc[::-1]
(null_at_least / obs_at_least).head(20)