In [None]:
%load_ext autoreload

In [None]:
import os

# Move from the notebook directory to its parent so that the relative `data/`, `ref/`,
# `meta/` paths and the `lib` imports used below resolve. The target is computed only on
# the first run, so re-executing this cell does not keep climbing the directory tree.
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = os.path.realpath(os.path.pardir)
os.chdir(PROJECT_ROOT)
os.path.realpath(os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
from scipy.spatial.distance import pdist, squareform

In [None]:
import sfacts as sf

In [None]:
# Input-file naming parameters shared by every path below.
# stemA: HMP2 dataset/processing stem (e.g. `data/{stemA}.gtpro.species_depth.tsv`).
stemA = 'hmp2.a.r.proc'
# stemB: strain-fit stem that follows `.gtpro.` in the StrainFacts output paths.
# Previously used alternative:
#   'filt-poly05-cvrg05.ss-g10000-block0-seed0.approx-clust2-thresh05-s95'
stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s80-seed0'
# centroid: MIDAS gene-clustering level (presumably % identity); selects the
# `midas_gene{centroid}` files and the `centroid_{centroid}` pangenome column.
centroid = 75

# Analysis Parameters

In [None]:
# Species analysed throughout. It is kept as a string to match the string-typed
# species_id columns read below.
species_id = '102506'
# Genes whose correlation with species depth exceeds this value are treated as
# "species genes" (see the species_corr histogram below).
species_gene_corr_thresh = 0.98

# Prepare Data

## Taxonomy

In [None]:
# Species-level taxonomy: one row per species_id with the raw ';'-delimited taxonomy
# string, plus one column per rank (p__ ... s__) taken from fields 1-6 of that string
# (field 0 is skipped).
_taxonomy_levels = {'p__': 1, 'c__': 2, 'o__': 3, 'f__': 4, 'g__': 5, 's__': 6}

species_taxonomy = (
    pd.read_table(
        'ref/gtpro/species_taxonomy_ext.tsv',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .assign(species_id=lambda df: df.species_id.astype(str))
    .set_index('species_id')
    [['taxonomy_string']]
)
_taxonomy_parts = species_taxonomy.taxonomy_string.str.split(';')
species_taxonomy = species_taxonomy.assign(**{
    level_name: _taxonomy_parts.apply(lambda parts, i=level_number: parts[i])
    for level_name, level_number in _taxonomy_levels.items()
})

species_taxonomy.loc[species_id]

In [None]:
midasdb_genomes = pd.read_table('ref_temp/uhgg_genomes_all_4644.tsv')

In [None]:
midasdb_genomes.Species_rep.value_counts()

## Species

In [None]:
all_species_depth = pd.read_table(f'data/{stemA}.gtpro.species_depth.tsv', index_col=['sample', 'species_id']).squeeze().unstack('species_id', fill_value=0).rename(str, axis='columns')
species_rabund = all_species_depth.divide(all_species_depth.sum(1), axis=0)

In [None]:
gtpro_species_depth = pd.read_table(f'data/sp-{species_id}.{stemA}.gtpro.species_depth.tsv', dtype=dict(sample=str, species_id=str, depth=float), index_col=['sample', 'species_id']).squeeze().unstack('species_id')

In [None]:
species_depth = pd.read_table(f'data_temp/sp-{species_id}.{stemA}.midas_gene{centroid}.species_depth.tsv', names=['sample', 'depth'], index_col='sample').squeeze()

In [None]:
gene_depth = xr.load_dataarray(f'data_temp/sp-{species_id}.{stemA}.midas_gene{centroid}.depth.nc').sel(sample=species_depth.index)

In [None]:
d = pd.DataFrame(dict(gtpro=gtpro_species_depth[species_id], midas=species_depth))

plt.scatter('gtpro', 'midas', data=d, s=3, alpha=0.3)
plt.plot([0, 1e2], [0, 1e2])
plt.yscale('symlog', linthresh=1e-4)
plt.xscale('symlog', linthresh=1e-4)

In [None]:
d = pd.DataFrame(dict(gtpro=gtpro_species_depth[species_id], midas=species_depth))

plt.scatter('gtpro', 'midas', data=np.cbrt(d), s=3, alpha=0.3)
# plt.plot([0, 1e2], [0, 1e2])
# plt.yscale('symlog', linthresh=1e-4)
# plt.xscale('symlog', linthresh=1e-4)

In [None]:
species_corr = pd.read_table(f'data_temp/sp-{species_id}.{stemA}.midas_gene{centroid}.species_correlation.tsv', names=['sample', 'correlation'], index_col='sample').squeeze()

## Metadata

In [None]:
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

sample_meta = mgen.join(preparation, on='preparation_id', rsuffix='_').join(stool, on='stool_id').join(subject, on='subject_id').loc[all_species_depth.index]

In [None]:
len(sample_meta.stool_id.unique()), len(sample_meta.subject_id.unique())

## Strains

In [None]:
# Load the StrainFacts fit for this species. drop_low_abundance_strains(0.05) presumably
# removes strains that never reach 5% abundance (semantics defined in sfacts).
fit = sf.World.load(
    f'data_temp/sp-{species_id}.{stemA}.gtpro.{stemB}.world.nc'
).drop_low_abundance_strains(0.05)
print(fit.sizes)

# Subsample up to 1000 positions for plotting. The global numpy seed is set first,
# presumably so random_sample is reproducible (assumes it draws from numpy's global RNG).
np.random.seed(0)
position_ss = fit.random_sample(position=min(fit.sizes['position'], 1000)).position

In [None]:
# Per-(gene, strain) correlation (missing pairs filled with 0) and depth ratio (missing
# pairs left NaN), each unstacked to genes x strains.
strain_corr = pd.read_table(
    f'data_temp/sp-{species_id}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_correlation.tsv',
    index_col=['gene_id', 'strain']
).squeeze().unstack(fill_value=0)
# strain_corr = strain_by_species_corr.sel(species_id=species_id).to_series().unstack('strain')
strain_depth = pd.read_table(
    f'data_temp/sp-{species_id}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_depth_ratio.tsv',
    index_col=['gene_id', 'strain']
).squeeze().unstack()
# Align the two tables on their default axis (presumably rows/genes), then on columns (strains).
strain_corr, strain_depth = align_indexes(*align_indexes(strain_corr, strain_depth), axis="columns")

In [None]:
# Map each sample to the strain whose community fraction exceeds 0.95. Samples with no
# such strain are dropped; if several passed, the first would be taken.
sample_to_strain = (
    (fit.community.data > 0.95)
    .to_series()
    .unstack()
    .apply(idxwhere, axis=1)
    [lambda x: x.apply(bool)]
    .str[0]
    .rename('strain')
)
    
# Inverse mapping (strain -> list of assigned samples); show strains with the most samples.
strain_to_sample_list = (
    sample_to_strain
    .rename('strain_id')
    .reset_index()
    .groupby('strain_id')
    .apply(lambda x: x['sample'].to_list())
)
strain_to_sample_list.apply(len).sort_values(ascending=False).head()

In [None]:
# "Species genes": genes (among those in strain_corr) whose species correlation exceeds
# species_gene_corr_thresh.
# species_gene_corr_thresh = species_corr.sort_values(ascending=False).head(n_species_genes + 1).min()
species_gene_list = idxwhere(species_corr.loc[strain_corr.index] > species_gene_corr_thresh)
print(len(species_gene_list))

In [None]:
# Distribution of per-gene species correlation, with the species-gene threshold marked.
# strain_thresh = pd.read_table(
#     f'data_temp/sp-{species_id}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_correlation_threshold.tsv',
#     names=['strain_id', 'threshold'],
#     index_col='strain_id',
# ).loc[strain_corr.columns]
plt.hist(species_corr, bins=np.linspace(0, 1, num=101))
plt.axvline(species_gene_corr_thresh, linestyle=':', color='k')

In [None]:
# Strain fractions per sample (samples x strains, missing = 0) from the fit's community table.
strain_frac = pd.read_table(f'data_temp/sp-{species_id}.{stemA}.gtpro.{stemB}.comm.tsv', index_col=['sample', 'strain']).squeeze().unstack(fill_value=0)

In [None]:
# log10 species depth over samples where the species is detected; dotted line at depth 1.
plt.hist(np.log10(species_depth[species_depth > 0]), bins=np.linspace(-4, 4))
plt.axvline(np.log10(1.0), linestyle=':', color='k')

## MIDAS Genes, COGs, COG categories

In [None]:
# MIDAS pangenome cluster table, indexed by centroid_99 gene ID (column kept), and the
# per-gene annotations (columns lower-cased, e.g. 'cog', 'product', 'ec_number').
gene_cluster = pd.read_table(
    f'ref_temp/midasdb_uhgg/pangenomes/{species_id}/cluster_info.txt'
).set_index('centroid_99', drop=False).rename_axis(index='gene_id')
gene_annotation = pd.read_table(
    f'ref_temp/midasdb_uhgg.sp-{species_id}.gene{centroid}_annotations.tsv',
    names=['locus_tag', 'ftype', 'length_bp', 'gene', 'EC_number', 'COG', 'product'],
    index_col='locus_tag',
).rename(columns=str.lower)

# One row per distinct centroid_{centroid} gene: its cluster-table row (looked up by
# centroid_99 ID) joined with its annotation.
gene_meta = gene_cluster.loc[gene_cluster[f'centroid_{centroid}'].unique()].join(gene_annotation)

In [None]:
gene_cluster

In [None]:
# COG metadata. `categories` is a string of one-letter category codes; cog_x_category is
# the long-format (cog, category) table with one row per letter (NaN padding dropped).
_cog_meta = pd.read_table(
    'ref/cog-20.meta.tsv',
    names=['cog', 'categories', 'description', 'gene', 'pathway', '_1', '_2'],
    index_col=['cog']
)
cog_meta = _cog_meta.drop(columns=['categories', '_1', '_2'])
cog_x_category = _cog_meta.categories.apply(tuple).apply(pd.Series).unstack().to_frame(name='category').reset_index()[['cog', 'category']].dropna()

In [None]:
# COG category code -> description.
cog_category = pd.read_table('ref/cog-20.categories.tsv', names=['category', 'description'], index_col='category')

## Genes

In [None]:
# NOTE(review): same file as `gene_depth` above, but without restricting to the samples in
# species_depth; `sample_depth` is not used by the cells in this notebook section.
sample_depth = xr.load_dataarray(f'data_temp/sp-{species_id}.{stemA}.midas_gene{centroid}.depth.nc')

## References

In [None]:
reference_meta = pd.read_table('ref_temp/uhgg_genomes_all_4644.tsv', index_col='Genome').rename_axis(index='genome_id')[lambda x: x.MGnify_accession == 'MGYG-HGUT-' + species_id[1:]].rename(lambda s: 'UHGG' + s[10:])
reference_meta.head()

In [None]:
reference_gene = xr.load_dataarray(f'data_temp/sp-{species_id}.midas_gene{centroid}.reference_copy_number.nc')
reference_gene = pd.DataFrame(reference_gene.T.values, index=reference_gene.gene_id, columns=reference_gene.genome_id)

In [None]:
isolate_gene = reference_gene[idxwhere(reference_meta.Genome_type == 'Isolate')]

# Select strains, genes

## QC Strains

In [None]:
# Per-strain gene-calling thresholds (correlation and depth-ratio cut-offs), with the
# columns renamed to the names used in the gene-selection cell below.
strain_thresholds = (
    pd.read_table(f'data_temp/sp-{species_id}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_gene_threshold.tsv', index_col='strain')
    .rename(columns=dict(
        correlation_strict='corr_threshold_strict',
        correlation_moderate='corr_threshold_moderate',
        correlation_lenient='corr_threshold_lenient',
        depth_high='depth_thresh_high',
        depth_low='depth_thresh_low',
    ))
)

In [None]:
# Per-strain QC table: thresholds, genotype entropy, mean metagenotype entropy of the
# strain's assigned samples, number of assigned samples, and stdev / max / sum of
# cube-root species depth over those samples. `.rename(int)` casts the groupby keys
# (strain labels), presumably to match the integer strain index.
_strain_meta = (
    strain_thresholds
    .join(fit.genotype.entropy().to_series().rename('genotype_entropy'))
    .join(fit.metagenotype.entropy().to_series().rename('metagenotype_entropy').groupby(sample_to_strain).mean().rename(int))
    .join(strain_to_sample_list.apply(len).rename('num_samples'))
    .join(species_depth.apply(np.cbrt).groupby(sample_to_strain).std().rename('depth_stdev').rename(int))
    .join(species_depth.apply(np.cbrt).groupby(sample_to_strain).max().rename('depth_max').rename(int))
    .join(species_depth.apply(np.cbrt).groupby(sample_to_strain).sum().rename('depth_sum').rename(int))
    # power_index = stdev(cube-root depth) * sqrt(#samples); 0 when undefined
    # (a single assigned sample, or none).
    .assign(power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_samples)).fillna(0))
)
strain_meta = _strain_meta
# "High-power" strains: power_index > 1 and mean metagenotype entropy < 0.05.
high_power_strain_list = idxwhere((strain_meta.power_index > 1.0) & (strain_meta.metagenotype_entropy < 0.05))
print(len(high_power_strain_list))
highest_power_strain_list = strain_meta.sort_values('power_index', ascending=False).head(3).index
strain_meta.sort_values('num_samples', ascending=False)

In [None]:
# Fixed colour per high-power strain (Spectral colormap), reused by later plots.
high_power_strain_palette = lib.plot.construct_ordered_palette(high_power_strain_list, mpl.cm.Spectral)

## Select Genes

In [None]:
# Per-(gene, strain) boolean calls (genes x strains) from each strain's thresholds.
# Flex comparisons (`.gt`/`.lt`) align the per-strain threshold Series on the columns.
# Plain `>`/`<` between a DataFrame and a Series warn in pandas 1.x and raise ValueError
# in pandas 2.x unless the Series index exactly equals the columns (same labels and order).
strict_corr_hit = strain_corr.gt(strain_meta.corr_threshold_strict, axis='columns')
lenient_corr_hit = strain_corr.gt(strain_meta.corr_threshold_lenient, axis='columns')
low_corr = strain_corr.lt(strain_meta.corr_threshold_lenient, axis='columns')
# Depth ratio strictly inside (low, high), strictly below low, strictly above high.
depth_hit = (
    strain_depth.lt(strain_meta.depth_thresh_high, axis='columns')
    & strain_depth.gt(strain_meta.depth_thresh_low, axis='columns')
)
low_depth = strain_depth.lt(strain_meta.depth_thresh_low, axis='columns')
high_depth = strain_depth.gt(strain_meta.depth_thresh_high, axis='columns')
high_confidence_hit = depth_hit & strict_corr_hit
maybe_hit = depth_hit & lenient_corr_hit
low_depth_hit = low_depth & strict_corr_hit
high_depth_hit = high_depth & strict_corr_hit
# Exactly one of (depth in range, strict correlation) holds.
ambiguous_hit = depth_hit ^ strict_corr_hit
# Confident absence: low depth ratio and low correlation.
high_confidence_not_hit = low_depth & low_corr

In [None]:
strain_meta.loc[high_power_strain_list]

In [None]:
high_confidence_hit[high_power_strain_list].sum()

In [None]:
samples_with_high_power_strains = idxwhere(fit.community.data.sel(strain=high_power_strain_list).sum("strain").to_series() > 0.5)
samples_without_high_power_strains = idxwhere(fit.community.data.sel(strain=high_power_strain_list).sum("strain").to_series() < 0.5)
len(samples_with_high_power_strains), len(samples_without_high_power_strains)

In [None]:
# Genotypes of the high-power strains at the subsampled positions (positions ordered by
# metagenotype linkage, strains by genotype linkage).
sf.plot.plot_genotype(
    fit.sel(strain=high_power_strain_list, position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage("position"),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)

In [None]:
# Metagenotypes of samples where high-power strains exceed half of the community.
# Column colours mark each sample's assigned strain (NaN where unassigned or not in the palette).
sf.plot.plot_metagenotype(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    row_linkage_func=lambda w: w.metagenotype.linkage("position"),
    col_colors=fit.sel(sample=samples_with_high_power_strains, position=position_ss).sample.to_series().map(sample_to_strain).map(high_power_strain_palette),
)

In [None]:
# Community composition of the same samples (strains below the 0.05 cut-off dropped).
sf.plot.plot_community(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss).drop_low_abundance_strains(0.05),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
    col_colors=fit.sel(sample=samples_with_high_power_strains, position=position_ss).sample.to_series().map(sample_to_strain).map(high_power_strain_palette),
)

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(high_confidence_hit[strain_list].sum(1) > 0)

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(high_confidence_hit[strain_list].mean(1) > 0.8)

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere((high_confidence_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

# Broad strokes characterization of gene sets

## Phylogenetic conservation

In [None]:
# Gene-content distance vs. genotype distance between high-power strains, over all genes.
strain_list = high_power_strain_list
gene_list = high_confidence_hit.index

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')  # NOTE(review): `m` is not used in this cell
x = high_confidence_hit.loc[gene_list, strain_list]

# Pairwise Jaccard distance between strains' high-confidence gene sets (strains x strains).
fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)
# Assumed to be a square strains x strains matrix in strain_list order, since it is
# passed to squareform alongside fdist — TODO confirm.
gdist = fit.genotype.sel(strain=strain_list).pdist()

# Condensed pairwise vectors of both distances, scatter plot, and Spearman correlation.
d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
# Same comparison restricted to variable genes (present in >5% and confidently absent
# in >20% of high-power strains).
strain_list = high_power_strain_list
gene_list = idxwhere((high_confidence_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')  # NOTE(review): `m` is not used in this cell
x = high_confidence_hit.loc[gene_list, strain_list]

fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)
gdist = fit.genotype.sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
tally_cog_category_reps = pd.merge(
    gene_meta.loc[idxwhere(high_confidence_hit[strain_list].any(1))].cog.value_counts().reset_index().rename(columns=dict(index='cog', cog='tally')),
    cog_x_category,
    on='cog'
).groupby('category').tally.sum().sort_values(ascending=False)
tally_cog_category_reps

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(15, 19), sharex=True, sharey=True)

for this_cog_category, ax in zip(tally_cog_category_reps.index, axs.flatten()):
    ax.set_title(this_cog_category)
    strain_list = high_power_strain_list
    cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
    gene_list = gene_meta.cog.isin(cog_list)

    x = high_confidence_hit.loc[gene_list, strain_list]
    fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)

    gdist = fit.genotype.sel(strain=strain_list).pdist()

    d = pd.DataFrame(dict(
        genotype_distance=sp.spatial.distance.squareform(gdist),
        gene_content_distance=sp.spatial.distance.squareform(fdist)
    ))
    ax.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
    ax.annotate(np.round(sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)[0], 2), xy=(0.8, 0.9), xycoords='axes fraction')
    ax.annotate(int(x.mean(1).sum()), xy=(0.8, 0.7), xycoords='axes fraction')
    ax.annotate(cog_category.loc[this_cog_category].description, xy=(0.0, 0.8), xycoords='axes fraction')

## Gene Clusters

In [None]:
from sklearn.cluster import OPTICS, AgglomerativeClustering, MiniBatchKMeans, KMeans, Birch, OPTICS
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA

In [None]:
strain_hit_gene_list = idxwhere((high_confidence_hit[high_power_strain_list].sum(1) > 0))
strain_hit_gene_list = idxwhere((reference_gene > 0).loc[strain_hit_gene_list].sum(1) > 2)

In [None]:
%%time
# Expect ~7 minutes
x = (reference_gene > 0).loc[strain_hit_gene_list]
clust = AgglomerativeClustering(n_clusters=None, distance_threshold=0.1, linkage='average', affinity='cosine').fit(x)

In [None]:
reference_gene_clust = pd.Series(clust.labels_, index=strain_hit_gene_list)

In [None]:
reference_gene_clust_list = idxwhere(reference_gene_clust.value_counts() > 1)

In [None]:
reference_gene_clust.value_counts()[reference_gene_clust_list]

In [None]:
genes_in_reference_gene_clust_list = idxwhere(reference_gene_clust.isin(reference_gene_clust_list))

In [None]:
gene_clust_palette = lib.plot.construct_ordered_palette(reference_gene_clust_list)

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(high_confidence_hit[high_power_strain_list].any(1))

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
        row_colors=reference_gene_clust.reindex(gene_list).map(gene_clust_palette),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(reference_gene_clust == reference_gene_clust_list[0])

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))
    
x2 = (high_confidence_hit.astype(int) + maybe_hit.astype(int) - high_confidence_not_hit.astype(int)).loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x2 + 1e-4,
        metric='cosine',
        # norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(reference_gene_clust == reference_gene_clust_list[1])

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))
    
x2 = (high_confidence_hit.astype(int) + maybe_hit.astype(int) - high_confidence_not_hit.astype(int)).loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x2 + 1e-4,
        metric='cosine',
        # norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(reference_gene_clust == reference_gene_clust_list[2])

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))
    
x2 = (high_confidence_hit.astype(int) + maybe_hit.astype(int) - high_confidence_not_hit.astype(int)).loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x2 + 1e-4,
        metric='cosine',
        # norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
gene_clust_depth = gene_depth.to_series().unstack().groupby(reference_gene_clust).mean().rename(int).T
gene_clust_depth_trimmed_mean = gene_depth.to_series().unstack().groupby(reference_gene_clust).apply(lambda x: pd.Series(sp.stats.trim_mean(x, 0.2), index=x.columns)).rename(int).T

In [None]:
species_with_cluster = (((gene_clust_depth_trimmed_mean.T / species_depth) > 0.1) & (species_depth > 1e-3)).T
species_with_cluster_sample_frac = species_with_cluster.groupby(sample_meta.subject_id).mean()

In [None]:
depth = species_depth_by_subject
depth_colors = ((depth / depth.max())**(1/4)).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[species_with_cluster_sample_frac.index].ibd_diagnosis.map({'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'})

sns.clustermap(
    species_with_cluster_sample_frac,
    metric='cosine',
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
)

In [None]:
plt.plot(species_with_cluster_sample_frac.std().sort_values().values)
plt.axhline(0.31)

In [None]:
species_with_cluster_sample_frac.std().gt(0.2).sum()

In [None]:
d0 = species_with_cluster_sample_frac
d1 = d0.loc[:, d0.std() > 0.2]
y = subject.ibd_diagnosis.isin(['UC', 'CD'])

gene_cluster_test = d1.apply(lambda x: sp.stats.mannwhitneyu(x[y], x[~y])).T[1].rename('pvalue')

In [None]:
# Benjamini-Hochberg adjusted p-values per cluster (`fdrcorrection` is imported at the top).
_fdr_rejected, _fdr_qvalues = fdrcorrection(gene_cluster_test)
fdr_gene_cluster_test = pd.Series(_fdr_qvalues, index=gene_cluster_test.index)

fdr_gene_cluster_test.sort_values().head(30)

In [None]:
d0 = species_with_cluster_sample_frac
# `y` is the whole subject table; it is joined below to supply the ibd_diagnosis hue.
y = subject#.ibd_diagnosis.isin(['UC', 'CD'])

# Long format: one row per (clust_id, subject_id) with its carriage fraction, plus subject metadata.
d1 = d0.rename_axis(columns='clust_id').unstack().to_frame('frac').reset_index().join(y, on='subject_id')

# Carriage fraction by diagnosis for clusters with FDR < 0.15.
fig, ax = plt.subplots(figsize=(25, 10))
lib.plot.boxplot_with_points(
    'clust_id',
    'frac',
    hue='ibd_diagnosis',
    data=d1[d1.clust_id.isin(idxwhere(fdr_gene_cluster_test < 0.15))],
    dodge=True,
    ax=ax)

In [None]:
# Carriage heatmap restricted to clusters with FDR < 0.2, same row colours as above.
d = species_with_cluster_sample_frac[idxwhere(species_with_cluster_sample_frac.columns.to_series().isin(idxwhere(fdr_gene_cluster_test < 0.2)))]
depth = species_depth_by_subject
depth_colors = ((depth / depth.max())**(1/4)).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[species_with_cluster_sample_frac.index].ibd_diagnosis.map({'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'})

sns.clustermap(
    d,
    metric='cosine',
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
)

In [None]:
# NOTE(review): the cluster IDs hard-coded below ([4, 174, 25]) appear to be copied from
# this output; confirm they still match whenever the clustering is re-run.
reference_gene_clust_list[0], reference_gene_clust_list[10], reference_gene_clust_list[2],

In [None]:
# Mean cluster depth per sample for the selected clusters (log colour scale).
sns.clustermap(gene_clust_depth[[4, 174, 25]], norm=mpl.colors.LogNorm())

In [None]:
# Per-subject carriage fraction for the selected clusters.
d = species_with_cluster_sample_frac[[4, 174, 25]]
depth = species_depth_by_subject
depth_colors = ((depth / depth.max())**(1/4)).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[species_with_cluster_sample_frac.index].ibd_diagnosis.map({'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'})

sns.clustermap(
    d,
    # metric='cosine',
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
)

In [None]:
# Number of genes in each cluster with FDR < 0.1.
reference_gene_clust[reference_gene_clust.isin(idxwhere(fdr_gene_cluster_test < 0.1))].value_counts()

In [None]:
# Gene metadata for cluster 147 (hard-coded ID from the clustering above).
gene_meta.loc[idxwhere(reference_gene_clust == 147)]

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(reference_gene_clust == 147)

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(reference_gene_clust == 602)

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(reference_gene_clust == 147)

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts()['hypothetical protein'])
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
# Distribution over samples of the summed community fraction of the high-power strains.
plt.hist(fit.community.sel(strain=high_power_strain_list).sum("strain"))

In [None]:
# Species depth vs. trimmed-mean depth of gene cluster 496, both divided by sequenced
# reads in millions. NOTE(review): cluster ID 496 is hard-coded and depends on the
# clustering result above.
x = species_depth / (sample_meta.sequenced_reads / 1e6)
y = gene_clust_depth_trimmed_mean[496] / (sample_meta.sequenced_reads / 1e6)

plt.scatter(x, y, s=5)
plt.yscale('symlog', linthresh=1e-5)
plt.xscale('symlog', linthresh=1e-5)

In [None]:
# Fraction of each multi-gene cluster's genes called high-confidence present, per
# high-power strain. Column colours: power_index relative to its max. Row colours: log10
# cluster size relative to its max. The 1e-4 offset avoids all-zero rows for the cosine metric.
d = high_confidence_hit[high_power_strain_list].groupby(reference_gene_clust).mean().rename(int).loc[reference_gene_clust_list]

sns.clustermap(
    d + 1e-4,
    metric='cosine',
    col_linkage=fit.genotype.sel(strain=d.columns).linkage(),
    col_colors=strain_meta.loc[d.columns].power_index.pipe(lambda x: x / x.max()).pipe(mpl.cm.viridis),
    row_colors=reference_gene_clust.value_counts().loc[reference_gene_clust_list].pipe(np.log10).pipe(lambda x: x / x.max()).pipe(mpl.cm.viridis),
    # figsize=(5, 50),
)

In [None]:
# The same per-strain fractions for the ten largest clusters, with cluster size.
high_confidence_hit[high_power_strain_list].groupby(reference_gene_clust).mean().rename(int).assign(clust_size=reference_gene_clust.value_counts()).sort_values('clust_size', ascending=False).head(10)

In [None]:
# As above, using the lenient `maybe_hit` calls.
maybe_hit[high_power_strain_list].groupby(reference_gene_clust).mean().rename(int).assign(clust_size=reference_gene_clust.value_counts()).sort_values('clust_size', ascending=False).head(10)

In [None]:
# Reference prevalence of each gene: the fraction of reference genomes carrying it.
gene_freq = (reference_gene > 0).mean(1).sort_values()

# Width of the gene-prevalence bins. Previously `window_size = 0.01` was defined but
# unused, while 0.05 was hard-coded in the call.
window_size = 0.05

def window_agg(df, by, width):
    """Bin the rows of `df` on column `by` and aggregate every column per bin.

    Rows are binned by rounding `df[by] / width`, so bins are centred on multiples of
    `width`.

    Parameters
    ----------
    df : pd.DataFrame
        Table to aggregate.
    by : str
        Column whose values define the bins.
    width : float
        Bin width.

    Returns
    -------
    pd.DataFrame
        Indexed by bin centre (index named `by`), with (column, 'mean' / 'count')
        MultiIndex columns.
    """
    df = df.assign(__window_idx=(df[by] / width).round())
    return df.groupby('__window_idx').agg(['mean', 'count']).rename(lambda x: x * width).rename_axis(index=by)

# For each high-power strain: fraction of genes in each prevalence bin that are
# high-confidence present, versus the bin's prevalence. Marker size is sqrt(#genes in bin).
for strain_id in high_power_strain_list:
    d = (
        window_agg(
            pd.DataFrame(dict(
                strain_content=high_confidence_hit[strain_id].astype(float).reindex(gene_freq.index, fill_value=0),
                gene_freq=gene_freq,
            )),
            by='gene_freq',
            width=window_size,
        )
        .strain_content
        .reset_index()
        .assign(sqrt_count=lambda x: np.sqrt(x['count']))
    )
    # Labels starting with '_' are omitted from the legend.
    plt.plot('gene_freq', 'mean', data=d, alpha=0.4, label='__none__')
    plt.scatter('gene_freq', 'mean', data=d, s='sqrt_count', linewidths=1, alpha=0.2, label=(strain_id, round(strain_meta.power_index[strain_id], 1)))
# y = x reference (strain content equal to reference prevalence); loop-invariant, drawn once.
plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='k')
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Same binned "content vs. gene frequency" curves, for strains with power_index < 1.
# Uses `window_agg` defined in the high-power-strain cell above (it is no longer
# re-defined here).
gene_freq = (reference_gene > 0).mean(1).sort_values()
window_size = 0.05  # bin width on the gene_freq axis (the value these plots used)

for strain_id in idxwhere(strain_meta.power_index < 1):
    # Genes missing from high_confidence_hit are treated as absent (0).
    strain_content = (
        high_confidence_hit[strain_id]
        .astype(float)
        .reindex(gene_freq.index, fill_value=0)
    )
    d = (
        window_agg(
            pd.DataFrame(dict(strain_content=strain_content, gene_freq=gene_freq)),
            by='gene_freq',
            width=window_size,
        )
        .strain_content
        .reset_index()
        .assign(sqrt_count=lambda x: np.sqrt(x['count']))
    )
    # Underscore-prefixed labels are left out of the legend.
    plt.plot('gene_freq', 'mean', data=d, alpha=0.4, label='__none__')
    plt.scatter(
        'gene_freq', 'mean', data=d, s='sqrt_count', linewidths=1, alpha=0.2,
        label=(strain_id, round(strain_meta.power_index[strain_id], 1)),
    )
# Reference diagonal, drawn once rather than once per strain.
plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='k')
plt.xlabel('reference gene frequency (bin centre)')
plt.ylabel('mean strain content in bin')
plt.legend(bbox_to_anchor=(1, 1));

In [None]:
# Same binned curves for a reproducible random sample of up to 10 reference_gene
# columns: each column's own presence calls (reference_gene > 0) against gene_freq.
# Uses `window_agg` defined in the high-power-strain cell above.
reference_present = reference_gene > 0
gene_freq = reference_present.mean(1).sort_values()
window_size = 0.05  # bin width on the gene_freq axis (the value these plots used)

# Fixed random_state so the same genomes are drawn on every re-run; capped so the
# cell still works when there are fewer than 10 columns.
sampled_reference_ids = reference_gene.columns.to_series().sample(
    min(10, reference_gene.shape[1]), random_state=0
)
for strain_id in sampled_reference_ids:
    strain_content = (
        reference_present[strain_id]
        .astype(float)
        .reindex(gene_freq.index, fill_value=0)
    )
    d = (
        window_agg(
            pd.DataFrame(dict(strain_content=strain_content, gene_freq=gene_freq)),
            by='gene_freq',
            width=window_size,
        )
        .strain_content
        .reset_index()
        .assign(sqrt_count=lambda x: np.sqrt(x['count']))
    )
    # Underscore-prefixed labels are left out of the legend.
    plt.plot('gene_freq', 'mean', data=d, alpha=0.4, label='__none__')
    plt.scatter(
        'gene_freq', 'mean', data=d, s='sqrt_count', linewidths=1, alpha=0.2,
        label=strain_id,
    )
# Reference diagonal, drawn once rather than once per genome.
plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='k')
plt.xlabel('reference gene frequency (bin centre)')
plt.ylabel('mean genome content in bin')
plt.legend(bbox_to_anchor=(1, 1));

In [None]:
# Same binned curves for a reproducible random sample of up to 10 isolates: each
# isolate's presence calls (isolate_gene > 0) against the reference gene_freq.
# Uses `window_agg` defined in the high-power-strain cell above.
gene_freq = (reference_gene > 0).mean(1).sort_values()
window_size = 0.05  # bin width on the gene_freq axis (the value these plots used)

# The loop iterates over isolate_gene columns, so content must come from
# isolate_gene. Previously it read (reference_gene > 0)[strain_id], which only
# works if every isolate id also happens to be a reference_gene column.
isolate_present = isolate_gene > 0

# Fixed random_state so the same isolates are drawn on every re-run; capped so the
# cell still works when there are fewer than 10 isolates.
sampled_isolate_ids = isolate_gene.columns.to_series().sample(
    min(10, isolate_gene.shape[1]), random_state=0
)
for strain_id in sampled_isolate_ids:
    # Genes not present in isolate_gene's index are treated as absent (0).
    strain_content = (
        isolate_present[strain_id]
        .astype(float)
        .reindex(gene_freq.index, fill_value=0)
    )
    d = (
        window_agg(
            pd.DataFrame(dict(strain_content=strain_content, gene_freq=gene_freq)),
            by='gene_freq',
            width=window_size,
        )
        .strain_content
        .reset_index()
        .assign(sqrt_count=lambda x: np.sqrt(x['count']))
    )
    # Underscore-prefixed labels are left out of the legend.
    plt.plot('gene_freq', 'mean', data=d, alpha=0.4, label='__none__')
    plt.scatter(
        'gene_freq', 'mean', data=d, s='sqrt_count', linewidths=1, alpha=0.2,
        label=strain_id,
    )
# Reference diagonal, drawn once rather than once per isolate.
plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='k')
plt.xlabel('reference gene frequency (bin centre)')
plt.ylabel('mean isolate content in bin')
plt.legend(bbox_to_anchor=(1, 1));