In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Every path below (data/, meta/, ref/) is relative to the project root, so make it
# the working directory. Raises KeyError if PROJECT_ROOT is not set for the kernel.
_root_dir = _os.environ['PROJECT_ROOT']
_os.chdir(_root_dir)
del _root_dir

# Show the resolved working directory as this cell's output.
_os.path.realpath(_os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain

In [None]:
import sfacts as sf

In [None]:
# Global figure styling for the notebook: seaborn's 'talk' context and a modest DPI.
sns.set_context('talk')
plt.rcParams['figure.dpi'] = 75

In [None]:
# Alternative styling, kept for reference:
# sns.set_context('talk')
# plt.rcParams['figure.dpi'] = 100

# Identifiers interpolated into the data paths below:
#   group    -> data/group/{group}/...  (NOTE(review): metadata paths hard-code meta/hmp2/)
#   centroid -> MIDAS gene clustering level (midas_gene{centroid}, centroid_{centroid})
#   stemA    -> file stem shared by the GT-Pro and MIDAS outputs
#   stemB    -> file stem of the strain fit loaded with sf.World.load
group = 'hmp2'
centroid = 75
stemA = 'r.proc'
stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts33-s80-seed0'
# stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.approx-clust2-thresh05-s95'

# Analysis Parameters

In [None]:
# Species under analysis; a string ID used in both the GT-Pro and MIDAS file paths.
species_id = '102506'

# Minimum correlation with species depth for a gene to count as a "species gene"
# (see species_gene_list below).
species_gene_corr_thresh = 0.98

# Previously explored settings, kept for reference:
# depth_ratio_bound = 3
# species_gene_corr_thresh = 0.99
# n_species_genes = 2000

# Prepare Data

## Taxonomy

In [None]:
# GT-Pro species taxonomy, indexed by species_id (as str). The temporary 'taxonomy_split'
# column holds the ';'-separated ranks of 'taxonomy_string'.
species_taxonomy = pd.read_table('ref/gtpro/species_taxonomy_ext.tsv', names=['genome_id', 'species_id', 'taxonomy_string']).assign(species_id=lambda x: x.species_id.astype(str)).set_index('species_id')[['taxonomy_string']].assign(taxonomy_split=lambda x: x.taxonomy_string.str.split(';'))

# Add one column per rank, taken by position in the split string (1 -> 'p__' ... 6 -> 's__').
# The lambda is applied immediately within each iteration, so the late-binding closure
# pitfall with `level_number` does not arise here.
for level_name, level_number in [('p__', 1), ('c__', 2), ('o__', 3), ('f__', 4), ('g__', 5), ('s__', 6)]:
    species_taxonomy = species_taxonomy.assign(**{level_name: species_taxonomy.taxonomy_split.apply(lambda x: x[level_number])})
species_taxonomy = species_taxonomy.drop(columns=['taxonomy_split'])

# Taxonomy of the focal species.
species_taxonomy.loc[species_id]

In [None]:
# MIDAS (UHGG) genome table.
midasdb_genomes = pd.read_table('ref/midasdb_uhgg/genomes.tsv')

In [None]:
# Number of genomes per species in the MIDAS database.
midasdb_genomes.species.value_counts()

## Species

In [None]:
all_species_depth = pd.read_table(f'data/group/{group}/{stemA}.gtpro.species_depth.tsv', index_col=['sample', 'species_id']).squeeze().unstack('species_id', fill_value=0).rename(str, axis='columns')
species_rabund = all_species_depth.divide(all_species_depth.sum(1), axis=0)

In [None]:
gtpro_species_depth = pd.read_table(f'data/group/{group}/species/sp-{species_id}/{stemA}.gtpro.species_depth.tsv', dtype=dict(sample=str, species_id=str, depth=float), index_col=['sample', 'species_id']).squeeze().unstack('species_id')

In [None]:
species_depth = pd.read_table(f'data/group/{group}/species/sp-{species_id}/{stemA}.midas_gene{centroid}.species_depth.tsv', names=['sample', 'depth'], index_col='sample').squeeze()

In [None]:
gene_depth = xr.load_dataarray(f'data/group/{group}/species/sp-{species_id}/{stemA}.midas_gene{centroid}.depth.nc').sel(sample=species_depth.index)

In [None]:
d = pd.DataFrame(dict(gtpro=gtpro_species_depth[species_id], midas=species_depth))

plt.scatter('gtpro', 'midas', data=d, s=3, alpha=0.3)
plt.plot([0, 1e2], [0, 1e2])
plt.yscale('symlog', linthresh=1e-4)
plt.xscale('symlog', linthresh=1e-4)

In [None]:
d = pd.DataFrame(dict(gtpro=gtpro_species_depth[species_id], midas=species_depth))

plt.scatter('gtpro', 'midas', data=np.cbrt(d), s=3, alpha=0.3)
# plt.plot([0, 1e2], [0, 1e2])
# plt.yscale('symlog', linthresh=1e-4)
# plt.xscale('symlog', linthresh=1e-4)

In [None]:
species_corr = pd.read_table(f'data/group/{group}/species/sp-{species_id}/{stemA}.midas_gene{centroid}.species_correlation.tsv', names=['sample', 'correlation'], index_col='sample').squeeze()

## Metadata

In [None]:
# HMP2 metadata tables, chained library -> preparation -> stool -> subject.
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

# One metadata row per sample in all_species_depth (sample IDs are assumed to be
# library_ids — TODO confirm). `.loc` raises KeyError if any sample lacks metadata.
# rsuffix='_' disambiguates columns shared by mgen and preparation.
sample_meta = mgen.join(preparation, on='preparation_id', rsuffix='_').join(stool, on='stool_id').join(subject, on='subject_id').loc[all_species_depth.index]

In [None]:
# Number of distinct stools and subjects represented.
len(sample_meta.stool_id.unique()), len(sample_meta.subject_id.unique())

## Strains

In [None]:
# Strain-deconvolution fit (sfacts World) for this species. Low-abundance strains are
# removed with drop_low_abundance_strains(0.05); see sfacts for the exact rule.
fit = sf.World.load(
    f'data/group/{group}/species/sp-{species_id}/{stemA}.gtpro.{stemB}.world.nc'
).drop_low_abundance_strains(0.05)
print(fit.sizes)

# Seed numpy's global RNG before subsampling; presumably random_sample draws from it
# (TODO confirm in sfacts). position_ss holds at most 1000 positions and is used to
# keep the genotype/metagenotype plots below small.
np.random.seed(0)
position_ss = fit.random_sample(position=min(fit.sizes['position'], 1000)).position

In [None]:
# Gene x strain tables: presumably each gene's correlation with each strain (missing = 0)
# and each strain's gene depth ratio.
strain_corr = pd.read_table(
    f'data/group/{group}/species/sp-{species_id}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_correlation.tsv',
    index_col=['gene_id', 'strain']
).squeeze().unstack(fill_value=0)
# strain_corr = strain_by_species_corr.sel(species_id=species_id).to_series().unstack('strain')
strain_depth = pd.read_table(
    f'data/group/{group}/species/sp-{species_id}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_depth_ratio.tsv',
    index_col=['gene_id', 'strain']
).squeeze().unstack()
# Align both tables on rows, then on columns. Assumes align_indexes restricts to shared
# labels along the given axis (see lib.pandas_util).
strain_corr, strain_depth = align_indexes(*align_indexes(strain_corr, strain_depth), axis="columns")

In [None]:
# Assign each sample to one dominant strain: the first strain whose community fraction
# exceeds 0.95. Samples with no such strain are dropped. (Assumes community dims unstack
# to samples x strains — TODO confirm dim order.)
sample_to_strain = (
    (fit.community.data > 0.95)
    .to_series()
    .unstack()
    .apply(idxwhere, axis=1)  # per sample: list of strains above the threshold
    [lambda x: x.apply(bool)]  # keep samples whose list is non-empty
    .str[0]  # first listed strain
    .rename('strain')
)
    
# Inverse mapping: strain -> list of the samples it dominates.
strain_to_sample_list = (
    sample_to_strain
    .rename('strain_id')
    .reset_index()
    .groupby('strain_id')
    .apply(lambda x: x['sample'].to_list())
)
# Strains dominating the most samples.
strain_to_sample_list.apply(len).sort_values(ascending=False).head()

In [None]:
# "Species genes": genes in strain_corr whose species correlation exceeds
# species_gene_corr_thresh.
# species_gene_corr_thresh = species_corr.sort_values(ascending=False).head(n_species_genes + 1).min()
species_gene_list = idxwhere(species_corr.loc[strain_corr.index] > species_gene_corr_thresh)
print(len(species_gene_list))

In [None]:
# strain_thresh = pd.read_table(
#     f'data_temp/sp-{species_id}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_correlation_threshold.tsv',
#     names=['strain_id', 'threshold'],
#     index_col='strain_id',
# ).loc[strain_corr.columns]
# Distribution of per-gene species correlations, with the species-gene threshold marked.
plt.hist(species_corr, bins=np.linspace(0, 1, num=101))
plt.axvline(species_gene_corr_thresh, linestyle=':', color='k')
plt.yscale('log')

In [None]:
# Strain fractions per sample (samples x strains, missing = 0).
# NOTE(review): not referenced elsewhere in the visible notebook.
strain_frac = pd.read_table(f'data/group/{group}/species/sp-{species_id}/{stemA}.gtpro.{stemB}.comm.tsv', index_col=['sample', 'strain']).squeeze().unstack(fill_value=0)

In [None]:
# Distribution of nonzero species depth on a log10 scale; dotted line at depth 1.
plt.hist(np.log10(species_depth[species_depth > 0]), bins=np.linspace(-4, 4))
plt.axvline(np.log10(1.0), linestyle=':', color='k')

## MIDAS genes, COGs, and COG categories

In [None]:
# MIDAS pangenome clustering for this species, indexed by centroid_99 (renamed gene_id).
# drop=False also keeps centroid_99 as a column.
gene_cluster = pd.read_table(
    f'ref/midasdb_uhgg/pangenomes/{species_id}/cluster_info.txt'
).set_index('centroid_99', drop=False).rename_axis(index='gene_id')
# Per-gene annotations for the centroid-level gene set; column names are lowercased.
gene_annotation = pd.read_table(
    f'ref/midasdb_uhgg_gene_annotations/sp-{species_id}.gene{centroid}_annotations.tsv',
    names=['locus_tag', 'ftype', 'length_bp', 'gene', 'EC_number', 'COG', 'product'],
    index_col='locus_tag',
).rename(columns=str.lower)

# One row per distinct centroid_{centroid} gene, looked up in the centroid_99-indexed table
# (raises KeyError if a centroid is not itself a centroid_99 ID), joined with its annotation.
gene_meta = gene_cluster.loc[gene_cluster[f'centroid_{centroid}'].unique()].join(gene_annotation)

In [None]:
gene_cluster

In [None]:
# COG-20 metadata. 'categories' is assumed to be a string of single-letter category codes;
# apply(tuple) splits it into characters, and cog_x_category has one (cog, category) row
# per character.
_cog_meta = pd.read_table(
    'ref/cog-20.meta.tsv',
    names=['cog', 'categories', 'description', 'gene', 'pathway', '_1', '_2'],
    index_col=['cog']
)
cog_meta = _cog_meta.drop(columns=['categories', '_1', '_2'])
cog_x_category = _cog_meta.categories.apply(tuple).apply(pd.Series).unstack().to_frame(name='category').reset_index()[['cog', 'category']].dropna()

In [None]:
# COG category code -> description.
cog_category = pd.read_table('ref/cog-20.categories.tsv', names=['category', 'description'], index_col='category')

## Genes

## References

In [None]:
# UHGG genome metadata for this species, selected by MGnify accession
# 'MGYG-HGUT-' + species_id without its first character. Genomes are relabelled as
# 'UHGG' + the genome ID with its first 10 characters removed.
reference_meta = pd.read_table('ref/uhgg_genomes_all_4644.tsv', index_col='Genome').rename_axis(index='genome_id')[lambda x: x.MGnify_accession == 'MGYG-HGUT-' + species_id[1:]].rename(lambda s: 'UHGG' + s[10:])
reference_meta.head()

In [None]:
# Gene copy number in reference genomes, as a genes x genomes DataFrame. Assumes the
# DataArray's dims are (genome_id, gene_id), so the transpose is genes x genomes.
reference_gene = xr.load_dataarray(f'ref/midasdb_uhgg_pangenomes/{species_id}/midas_gene{centroid}.reference_copy_number.nc')
reference_gene = pd.DataFrame(reference_gene.T.values, index=reference_gene.gene_id, columns=reference_gene.genome_id)

In [None]:
# Restrict to genomes whose Genome_type is 'Isolate'.
isolate_gene = reference_gene[idxwhere(reference_meta.Genome_type == 'Isolate')]

# QC Strains

In [None]:
# Per-strain gene-calling thresholds, renamed to the column names used below.
strain_thresholds = (
    pd.read_table(f"data/group/{group}/species/sp-{species_id}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_gene_threshold.tsv", index_col='strain')
    .rename(columns=dict(
        correlation_strict='corr_threshold_strict',
        correlation_moderate='corr_threshold_moderate',
        correlation_lenient='corr_threshold_lenient',
        depth_high='depth_thresh_high',
        depth_low='depth_thresh_low',
    ))
)

In [None]:
# Per-strain QC summary:
#   genotype_entropy     - entropy of the strain's genotype
#   metagenotype_entropy - mean metagenotype entropy over the samples the strain dominates
#   num_samples          - number of samples the strain dominates
#   depth_stdev/max/sum  - std/max/sum of cube-root species depth over those samples
#   power_index          - depth_stdev * sqrt(num_samples), 0 when undefined; presumably a
#                          measure of how much depth variation is available to call genes
# `.rename(int)` casts the group labels (strain IDs) to int, presumably to match
# strain_thresholds' index.
_strain_meta = (
    strain_thresholds
    .join(fit.genotype.entropy().to_series().rename('genotype_entropy'))
    .join(fit.metagenotype.entropy().to_series().rename('metagenotype_entropy').groupby(sample_to_strain).mean().rename(int))
    .join(strain_to_sample_list.apply(len).rename('num_samples'))
    .join(species_depth.apply(np.cbrt).groupby(sample_to_strain).std().rename('depth_stdev').rename(int))
    .join(species_depth.apply(np.cbrt).groupby(sample_to_strain).max().rename('depth_max').rename(int))
    .join(species_depth.apply(np.cbrt).groupby(sample_to_strain).sum().rename('depth_sum').rename(int))
    .assign(power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_samples)).fillna(0))
)
strain_meta = _strain_meta
# Strains with power_index > 1 and low mean metagenotype entropy (< 0.05).
high_power_strain_list = idxwhere((strain_meta.power_index > 1.0) & (strain_meta.metagenotype_entropy < 0.05))
print(len(high_power_strain_list))
highest_power_strain_list = strain_meta.sort_values('power_index', ascending=False).head(3).index
strain_meta.sort_values('num_samples', ascending=False)

# Select Genes

In [None]:
strict_corr_hit = strain_corr > strain_meta.corr_threshold_strict
lenient_corr_hit = strain_corr > strain_meta.corr_threshold_lenient
moderate_corr_hit = strain_corr > strain_meta.corr_threshold_moderate
low_corr =  strain_corr < strain_meta.corr_threshold_lenient

low_depth = (strain_depth < strain_meta.depth_thresh_low)
depth_hit = ~low_depth
high_depth = (strain_depth > strain_meta.depth_thresh_high)
high_confidence_hit = depth_hit & strict_corr_hit
moderate_hit = depth_hit & moderate_corr_hit
maybe_hit = depth_hit & lenient_corr_hit
low_depth_hit = low_depth & strict_corr_hit
high_depth_hit = high_depth & strict_corr_hit
ambiguous_hit = depth_hit ^ strict_corr_hit
high_confidence_not_hit = low_depth & low_corr

In [None]:
high_confidence_hit[high_power_strain_list]

In [None]:
samples_with_high_power_strains = idxwhere(fit.community.data.sel(strain=high_power_strain_list).sum("strain").to_series() > 0.5)
samples_without_high_power_strains = idxwhere(fit.community.data.sel(strain=high_power_strain_list).sum("strain").to_series() < 0.5)
len(samples_with_high_power_strains), len(samples_without_high_power_strains)

In [None]:
# Genotypes of the high-power strains over the subsampled positions. Positions are ordered
# by metagenotype linkage and strains by genotype linkage.
sf.plot.plot_genotype(
    fit.sel(strain=high_power_strain_list, position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage("position"),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)

In [None]:
# Metagenotypes of samples in which high-power strains make up > 50% of the community.
sf.plot.plot_metagenotype(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    row_linkage_func=lambda w: w.metagenotype.linkage("position"),
)

In [None]:
# Community composition of those samples; low-abundance strains are dropped again after subsetting.
sf.plot.plot_community(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss).drop_low_abundance_strains(0.05),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)

# Phylogenetic conservation

In [None]:
strain_list = high_power_strain_list

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = high_confidence_hit[strain_list]
fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)

gdist = fit.genotype.sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
strain_list = high_power_strain_list

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = high_confidence_hit[strain_list].groupby(m.cog).any()
fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)

gdist = fit.genotype.sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
strain_list = high_power_strain_list

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = high_confidence_hit[strain_list].groupby(m.pathway).any()
fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)

gdist = fit.genotype.sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
strain_list = high_power_strain_list

tally_cog_category_reps = pd.merge(
    gene_meta.loc[idxwhere(high_confidence_hit[strain_list].any(1))].cog.value_counts().reset_index().rename(columns=dict(index='cog', cog='tally')),
    cog_x_category,
    on='cog'
).groupby('category').tally.sum().sort_values(ascending=False)
tally_cog_category_reps

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(15, 19), sharex=True, sharey=True)

for this_cog_category, ax in zip(tally_cog_category_reps.index, axs.flatten()):
    ax.set_title(this_cog_category)
    strain_list = high_power_strain_list
    cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
    gene_list = gene_meta.cog.isin(cog_list)

    x = high_confidence_hit.loc[gene_list, strain_list]
    fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)

    gdist = fit.genotype.sel(strain=strain_list).pdist()

    d = pd.DataFrame(dict(
        genotype_distance=sp.spatial.distance.squareform(gdist),
        gene_content_distance=sp.spatial.distance.squareform(fdist)
    ))
    ax.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
    ax.annotate(np.round(sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)[0], 2), xy=(0.8, 0.9), xycoords='axes fraction')
    ax.annotate(int(x.mean(1).sum()), xy=(0.8, 0.7), xycoords='axes fraction')
    ax.annotate(cog_category.loc[this_cog_category].description, xy=(0.0, 0.8), xycoords='axes fraction')

In [None]:
strain_list = high_power_strain_list
# Variable genes: moderate hits in > 20% of high-power strains AND confidently absent
# (low depth and low correlation) in > 20% of them.
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.2) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = high_confidence_hit.loc[gene_list, strain_list]

# Jaccard gene-content distance vs. genotype distance between strains.
# (The unused gene_meta/cog_meta join that was computed here has been removed.)
fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)
gdist = fit.genotype.sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
strain_list = high_power_strain_list
# Genes with a moderate hit in more than 70% of high-power strains.
gene_list = idxwhere(moderate_hit[strain_list].mean(1) > 0.7)

x = strain_depth.loc[gene_list, strain_list]

# Depth-ratio heatmap, skipped when there are too many rows to draw. Strain columns
# are ordered by genotype linkage.
if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Gene count, and how many are not annotated 'hypothetical protein'.
# .get(..., 0) avoids a KeyError when no gene carries that product.
print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts().get('hypothetical protein', 0))
print()
# Top COG pathways among these genes.
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
# Top gene products.
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
# Top COG categories, with descriptions.
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
# Variable genes: moderate hits in > 5% of high-power strains AND confidently absent in > 20%.
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = strain_depth.loc[gene_list, strain_list]

# Depth-ratio heatmap, skipped when there are too many rows to draw. Strain columns
# are ordered by genotype linkage.
if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Gene count, and how many are not annotated 'hypothetical protein'.
# .get(..., 0) avoids a KeyError when no gene carries that product.
print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts().get('hypothetical protein', 0))
print()
# Top COG pathways among these genes.
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
# Top gene products.
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
# Top COG categories, with descriptions.
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
# Genes called at unusually high depth (with strict correlation) in > 10% of high-power strains.
gene_list = idxwhere(high_depth_hit[strain_list].mean(1) > 0.1)

x = strain_depth.loc[gene_list, strain_list]

# Depth-ratio heatmap, skipped when there are too many rows to draw. Strain columns
# are ordered by genotype linkage.
if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Gene count, and how many are not annotated 'hypothetical protein'.
# .get(..., 0) avoids a KeyError when no gene carries that product.
print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts().get('hypothetical protein', 0))
print()
# Top COG pathways among these genes.
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
# Top gene products.
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
# Top COG categories, with descriptions.
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05)

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = high_confidence_hit.loc[gene_list, strain_list]

fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)
gdist = fit.genotype.sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(moderate_hit[strain_list].mean(1) > 0.05)

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = moderate_hit.loc[gene_list, strain_list]

fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)
gdist = fit.genotype.sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
# COG -> list of its distinct categories, plus the number of categories and the first one.
cog_to_category_list = (
    cog_x_category
    .groupby('cog')
    .apply(lambda x: list(x.category.unique()))
    .to_frame("categories")
    .assign(
        num_categories=lambda x: x.categories.str.len(),
        first_category=lambda x: x.categories.str[0]
    )
)
# NOTE(review): not the cell's last expression, so this listing of multi-category COGs is never displayed.
cog_to_category_list.loc[idxwhere(cog_to_category_list.num_categories > 1)]
# Collapse multi-category COGs to their first listed category.
cog_to_category = cog_to_category_list.first_category
cog_category_order = cog_to_category.unique()
# One color per category (project helper in lib.plot).
cog_category_palette = lib.plot.construct_ordered_palette(cog_category_order)

In [None]:
strain_list = high_power_strain_list
# Genes with a moderate (but not high-depth) hit in > 15% of high-power strains.
gene_list = idxwhere((moderate_hit & ~high_depth_hit)[strain_list].mean(1) > 0.15)

x = strain_depth.loc[gene_list, strain_list]

# Depth-ratio heatmap, skipped when there are too many rows to draw. Strain columns
# are ordered by genotype linkage.
if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        # row_colors=x.index.to_series().map(gene_annotation.cog).map(cog_to_category).map(cog_category_palette),
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
        cbar_pos=(0.05, 0.95, 0.1, 0.05), cbar_kws=dict(orientation='horizontal'),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Gene count, and how many are not annotated 'hypothetical protein'.
# .get(..., 0) avoids a KeyError when no gene carries that product.
print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts().get('hypothetical protein', 0))
print()
# Top COG pathways among these genes.
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
# Top gene products.
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
# Top COG categories, with descriptions.
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

## What are the phylogenetically co-conserved clusters of genes?

### What about within energy metabolism (COG category C)?

In [None]:
strain_list = high_power_strain_list
this_cog_category = 'C'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
# Genes in this COG category with a moderate hit in > 5% of high-power strains.
# Sorted so that row order (and the labelled rows) does not depend on set iteration
# order, which varies between sessions for string keys (hash randomization).
gene_list = sorted(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(moderate_hit[strain_list].mean(1) > 0.05))
)

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    cg = sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=5,
        xticklabels=1,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

#### TCA Cycle genes?

In [None]:
strain_list = high_power_strain_list
this_cog_pathway = 'C'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
gene_list = list(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(moderate_hit[strain_list].mean(1) > 0.05))
)
# gene_list = idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05)

x = strain_depth.loc[gene_list, strain_list]

tca_cycle_gene_cluster_list = idxwhere(pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x, metric='cosine')), index=x.index, columns=x.index).loc['UHGG000026_04491'].sort_values() < 0.1)
print(len(tca_cycle_gene_cluster_list))
gene_meta.join(cog_meta, on='cog', rsuffix='_cog_name').loc[tca_cycle_gene_cluster_list][['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]

In [None]:
strain_list = high_power_strain_list
gene_list = tca_cycle_gene_cluster_list

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    cg = sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=1,
        xticklabels=1,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

In [None]:
yes_clust_x_strains = [22, 85, 7, 8, 57, 34, 39, 67]
no_clust_x_strains = [1, 3, 10, 16, 6, 27, 36, 53]
no_clust_x_strain_samples = idxwhere(sample_to_strain.isin(no_clust_x_strains))
yes_clust_x_strain_samples = idxwhere(sample_to_strain.isin(yes_clust_x_strains))

sp.stats.mannwhitneyu(species_rabund[species_id][no_clust_x_strain_samples], species_rabund[species_id][yes_clust_x_strain_samples])

In [None]:
d = sample_to_strain.to_frame('strain').assign(rabund=species_rabund[species_id]).join(sample_meta).groupby(['subject_id', 'strain']).rabund.mean().reset_index().assign(clust_x_strain=lambda x: x.strain.isin(yes_clust_x_strains))
sns.stripplot('strain', 'rabund', data=d, hue='clust_x_strain', order=yes_clust_x_strains + no_clust_x_strains)
plt.yscale('log')

In [None]:
import lib.stats

In [None]:
# Same abundance comparison via the project helper, on the per-(subject, strain) table `d`.
lib.stats.mannwhitneyu('clust_x_strain', 'rabund', data=d)

In [None]:
# NOTE(review): this overwrites `d`, and the new value is not used before `d` is reassigned below.
d = sample_to_strain.to_frame('strain').assign(rabund=species_rabund[species_id])


In [None]:
gene_list = tca_cycle_gene_cluster_list

# Number of cluster genes present in each isolate genome (genes missing from the table -> False).
x = (isolate_gene > 0).reindex(gene_list).fillna(False)
plt.hist(x.sum())

In [None]:
# Per-subject summed depth of the cluster genes. Row colors show each subject's total
# species depth, normalized to the maximum and fourth-root scaled.
d = gene_depth.sel(gene_id=tca_cycle_gene_cluster_list).to_series().unstack('gene_id').groupby(sample_meta.subject_id).sum()
c = species_depth.groupby(sample_meta.subject_id).sum()
c = (c/c.max())**(1/4)

sns.clustermap(
    d + 1e-4,  # small offset; presumably avoids all-zero rows, for which cosine distance is undefined
    row_colors=c.map(mpl.cm.viridis),
    metric='cosine',
    norm=mpl.colors.PowerNorm(1/4),
)


#### Respiration genes?

In [None]:
strain_list = high_power_strain_list
this_cog_pathway = 'C'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
gene_list = list(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(moderate_hit[strain_list].mean(1) > 0.05))
)
# gene_list = idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05)

x = strain_depth.loc[gene_list, strain_list]

respiration_gene_cluster_list = idxwhere(pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x, metric='cosine')), index=x.index, columns=x.index).loc['UHGG144268_03892'].sort_values() < 0.1)
print(len(respiration_gene_cluster_list))
gene_meta.join(cog_meta, on='cog', rsuffix='_cog_name').loc[respiration_gene_cluster_list][['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]

In [None]:
strain_list = high_power_strain_list
gene_list = respiration_gene_cluster_list

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    cg = sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=5,
        xticklabels=1,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

In [None]:
# Same hand-picked strain groups as in the TCA analysis, with yes/no swapped
# (presumably this cluster is carried by the other group).
no_clust_x_strains = [22, 85, 7, 8, 57, 34, 39, 67]
yes_clust_x_strains = [1, 3, 10, 16, 6, 27, 36, 53]
no_clust_x_strain_samples = idxwhere(sample_to_strain.isin(no_clust_x_strains))
yes_clust_x_strain_samples = idxwhere(sample_to_strain.isin(yes_clust_x_strains))

# Compare the species' relative abundance between samples dominated by each group.
sp.stats.mannwhitneyu(species_rabund[species_id][no_clust_x_strain_samples], species_rabund[species_id][yes_clust_x_strain_samples])

In [None]:
gene_list = respiration_gene_cluster_list

# Number of cluster genes present in each isolate genome (genes missing from the table -> False).
x = (isolate_gene > 0).reindex(gene_list).fillna(False)
plt.hist(x.sum())

In [None]:
# Per-subject summed depth of the cluster genes (symlog color scale, linthresh 0.1).
# Row colors show each subject's total species depth, normalized and fourth-root scaled.
d = gene_depth.sel(gene_id=respiration_gene_cluster_list).to_series().unstack('gene_id').groupby(sample_meta.subject_id).sum()
c = species_depth.groupby(sample_meta.subject_id).sum()
c = (c/c.max())**(1/4)

sns.clustermap(
    d,
    row_colors=c.map(mpl.cm.viridis),
    # metric='cosine',
    norm=mpl.colors.SymLogNorm(1e-1),
)


### What about within carbohydrate metabolism (COG category G)?

In [None]:
strain_list = high_power_strain_list
this_cog_category = 'G'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
# Genes in this COG category with a moderate hit in > 5% of high-power strains.
# Sorted so that row order (and the labelled rows) does not depend on set iteration
# order, which varies between sessions for string keys (hash randomization).
gene_list = sorted(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(moderate_hit[strain_list].mean(1) > 0.05))
)

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    cg = sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=13,
        xticklabels=1,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

#### Ribose import genes?

In [None]:
strain_list = high_power_strain_list
this_cog_pathway = 'C'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
gene_list = list(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05))
)
guide_gene = 'UHGG000489_02378'
# gene_list = idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05)

x = strain_depth.loc[gene_list, strain_list]

_gene_cluster_list = idxwhere(pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x, metric='cosine')), index=x.index, columns=x.index).loc[guide_gene].sort_values() < 0.1)
print(len(_gene_cluster_list))

ribose_gene_cluster_list = _gene_cluster_list  # FIXME

gene_meta.join(cog_meta, on='cog', rsuffix='_cog_name').loc[_gene_cluster_list][['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]

In [None]:
strain_list = high_power_strain_list
gene_list = ribose_gene_cluster_list

x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    cg = sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=1,
        xticklabels=1,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

In [None]:
gene_list = ribose_gene_cluster_list

# Number of cluster genes present in each reference genome (genes missing from the table -> False).
x = (reference_gene > 0).reindex(gene_list).fillna(False)
plt.hist(x.sum())

In [None]:
# Per-subject summed depth of the cluster genes. Rows are annotated with total species
# depth (normalized, fourth-root scaled, viridis) and IBD diagnosis.
gene_list = ribose_gene_cluster_list
d = gene_depth.sel(gene_id=gene_list).to_series().unstack('gene_id').groupby(sample_meta.subject_id).sum()
depth = species_depth.groupby(sample_meta.subject_id).sum()
depth_colors = ((depth / depth.max())**(1/4)).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[d.index].ibd_diagnosis.map({'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'})

sns.clustermap(
    d,
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
    # metric='cosine',
    norm=mpl.colors.PowerNorm(1/4),
)


#### Fructose/mannose catabolism genes?

In [None]:
strain_list = high_power_strain_list
this_cog_pathway = 'C'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
gene_list = list(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(moderate_hit[strain_list].mean(1) > 0.05))
)
guide_gene = 'UHGG033023_00488'
# gene_list = idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05)

x = strain_depth.loc[gene_list, strain_list]

_gene_cluster_list = idxwhere(
    pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x, metric='cosine')), index=x.index, columns=x.index)
    .loc[guide_gene].sort_values() < 0.05
)
print(len(_gene_cluster_list))

mannose_gene_cluster_list = _gene_cluster_list  # FIXME

gene_meta.join(cog_meta, on='cog', rsuffix='_cog_name').loc[_gene_cluster_list][['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]

In [None]:
gene_list = mannose_gene_cluster_list
strain_list = high_power_strain_list

# Depth of each mannose-cluster gene in each high-power strain.
x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) >= 2e4:
    print("Too many genes for clustermap:", len(gene_list))
else:
    # Strain columns follow the fitted genotype linkage; gene rows use cosine distance.
    cg = sns.clustermap(
        x,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        xticklabels=1,
        yticklabels=1,
    )

In [None]:
# Samples listed for strains 53 and 27 in `strain_to_sample_list`, concatenated in that order.
sample_list = strain_to_sample_list[53] + strain_to_sample_list[27]
sf.plot.plot_metagenotype(
    fit.sel(sample=sample_list, position=position_ss),
    row_linkage_func=lambda w: w.genotype.linkage('position'),
)

In [None]:
# For each column (genome or strain) of each presence table, count how many
# mannose-cluster genes it carries, and compare the distributions.
gene_list = mannose_gene_cluster_list

# fill_value keeps the frames boolean; reindex + fillna upcasts to object dtype and can
# trigger pandas' fillna-downcasting FutureWarning.
x = (reference_gene > 0).reindex(gene_list, fill_value=False)
y = (isolate_gene > 0).reindex(gene_list, fill_value=False)
z = high_confidence_hit[high_power_strain_list].reindex(gene_list, fill_value=False)
a = moderate_hit[high_power_strain_list].reindex(gene_list, fill_value=False)

# Integer-aligned bins; the counts range from 0 to len(gene_list).
bins = np.arange(len(gene_list) + 3)
# Semi-transparent so later histograms do not hide earlier ones.
plt.hist(x.sum(), bins=bins, alpha=0.5, label='reference_gene')
plt.hist(y.sum(), bins=bins, alpha=0.5, label='isolate_gene')
plt.hist(z.sum(), bins=bins, alpha=0.5, label='high_confidence_hit (high-power strains)')
plt.hist(a.sum(), bins=bins, alpha=0.5, label='moderate_hit (high-power strains)')
plt.xlabel('cluster genes present')
plt.ylabel('count')
plt.legend()

plt.yscale('log')

In [None]:
gene_list = mannose_gene_cluster_list
# Per-sample depth of each mannose-cluster gene, restricted to the current `sample_list`.
d = (
    gene_depth
    .sel(gene_id=gene_list, sample=sample_list)
    .to_series()
    .unstack('gene_id')
)

sns.clustermap(d, norm=mpl.colors.PowerNorm(gamma=0.25))


In [None]:
# Total depth of each mannose-cluster gene per subject, annotated with per-subject
# species depth and IBD diagnosis.
gene_list = mannose_gene_cluster_list
d = (
    gene_depth.sel(gene_id=gene_list)
    .to_series()
    .unstack('gene_id')
    .groupby(sample_meta.subject_id)
    .sum()
)
depth = species_depth.groupby(sample_meta.subject_id).sum()
# Fourth-root scaling compresses the dynamic range before mapping to colors.
depth_colors = ((depth / depth.max()) ** 0.25).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[d.index, 'ibd_diagnosis'].map(
    {'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'}
)

sns.clustermap(
    d,
    norm=mpl.colors.PowerNorm(gamma=0.25),
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
)


### What about genes in the cell-envelope category (COG M)?

In [None]:
# Cell-envelope genes (COG category 'M') with a high-confidence hit in more than 10% of
# the high-power strains, shown as a gene-by-strain depth heatmap.
strain_list = high_power_strain_list
this_cog_category = 'M'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
# Sorted so the gene order (and the clustermap built from it) is reproducible: the
# iteration order of a set of strings changes between kernels (hash randomization).
gene_list = sorted(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(high_confidence_hit[strain_list].mean(1) > 0.1))
)

x = strain_depth.loc[gene_list, strain_list]
# Alternative input: x = high_confidence_hit.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    cg = sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=30,  # label every 30th gene to keep the axis readable
        xticklabels=1,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

#### O-antigen genes? (cosine distance < 0.1 to guide gene UHGG004518_00665)

In [None]:
# Cell-envelope genes (COG 'M') with a high-confidence hit in more than 5% of the
# high-power strains. Keep those whose strain-depth profile is within cosine distance
# 0.1 of the guide gene.
strain_list = high_power_strain_list
# Same category as the cell-envelope section above; set explicitly so this cell does
# not depend on which cell ran last.
this_cog_category = 'M'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
# Sorted for a reproducible gene order (set iteration order of strings changes between
# kernels because of hash randomization).
gene_list = sorted(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05))
)
guide_gene = 'UHGG004518_00665'
assert guide_gene in gene_list, f"Guide gene {guide_gene} did not pass the gene filters."

x = strain_depth.loc[gene_list, strain_list]

# Cosine distance from the guide gene to every selected gene. Only one row of the
# pairwise matrix is needed, so compute just that row.
_guide_gene_dist = pd.Series(
    sp.spatial.distance.cdist(x.loc[[guide_gene]], x, metric='cosine')[0],
    index=x.index,
)
_gene_cluster_list = idxwhere(_guide_gene_dist.sort_values() < 0.1)
print(len(_gene_cluster_list))

x_gene_cluster_list = _gene_cluster_list  # FIXME

gene_meta.join(cog_meta, on='cog', rsuffix='_cog_name').loc[_gene_cluster_list][['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]

In [None]:
gene_list = x_gene_cluster_list
strain_list = high_power_strain_list

# Depth of each cluster gene in each high-power strain.
x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) >= 2e4:
    print("Too many genes for clustermap:", len(gene_list))
else:
    # Strain columns follow the fitted genotype linkage; gene rows use cosine distance.
    cg = sns.clustermap(
        x,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        xticklabels=1,
        yticklabels=1,
    )

In [None]:
from itertools import chain

In [None]:
strain_list = [36, 67, 39]
# Concatenate the sample lists of the chosen strains, in strain order.
sample_list = [
    sample
    for strain_id in strain_list
    for sample in strain_to_sample_list[strain_id]
]
sf.plot.plot_metagenotype(
    fit.sel(sample=sample_list, position=position_ss),
    row_linkage_func=lambda w: w.genotype.linkage('position'),
)

In [None]:
# For each column (genome or strain) of each presence table, count how many cluster
# genes it carries, and compare the distributions.
gene_list = x_gene_cluster_list

# fill_value keeps the frames boolean; reindex + fillna upcasts to object dtype and can
# trigger pandas' fillna-downcasting FutureWarning.
x = (reference_gene > 0).reindex(gene_list, fill_value=False)
y = (isolate_gene > 0).reindex(gene_list, fill_value=False)
z = high_confidence_hit.reindex(gene_list, fill_value=False)

# Integer-aligned bins; the counts range from 0 to len(gene_list).
bins = np.arange(len(gene_list) + 3)
# Semi-transparent so later histograms do not hide earlier ones.
plt.hist(x.sum(), bins=bins, alpha=0.5, label='reference_gene')
plt.hist(y.sum(), bins=bins, alpha=0.5, label='isolate_gene')
plt.hist(z.sum(), bins=bins, alpha=0.5, label='high_confidence_hit (all strains)')
plt.xlabel('cluster genes present')
plt.ylabel('count')
plt.legend()

plt.yscale('log')

In [None]:
gene_list = x_gene_cluster_list
# Per-sample depth of each cluster gene (linear color scale), restricted to the
# current `sample_list`.
d = (
    gene_depth
    .sel(gene_id=gene_list, sample=sample_list)
    .to_series()
    .unstack('gene_id')
)

sns.clustermap(d)


In [None]:
# Total depth of each cluster gene per subject on a log color scale, annotated with
# per-subject species depth and IBD diagnosis.
gene_list = x_gene_cluster_list
d = (
    gene_depth.sel(gene_id=gene_list)
    .to_series()
    .unstack('gene_id')
    .groupby(sample_meta.subject_id)
    .sum()
)
depth = species_depth.groupby(sample_meta.subject_id).sum()
# Fourth-root scaling compresses the dynamic range before mapping to colors.
depth_colors = ((depth / depth.max()) ** 0.25).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[d.index, 'ibd_diagnosis'].map(
    {'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'}
)

sns.clustermap(
    d + 1e-2,  # pseudocount so zero depths can be drawn with LogNorm
    norm=mpl.colors.LogNorm(),
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
)


In [None]:
# Mean relative abundance of strain 67 per subject, highest first.
(
    sample_meta
    .assign(strain_rabund=fit.community.data.sel(strain=67).to_series())
    .groupby('subject_id')['strain_rabund']
    .mean()
    .sort_values(ascending=False)
)

In [None]:
# Samples in which any strain of `strain_list` has a community value above 0.05.
sample_list = idxwhere(
    (fit.community.data.sel(strain=strain_list) > 0.05)
    .any("strain")
    .to_series()
)
sf.plot.plot_community(
    fit.sel(sample=sample_list).drop_low_abundance_strains(0.05),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)

In [None]:
gene_list = x_gene_cluster_list
# Per-sample depth of each cluster gene, restricted to the current `sample_list`.
d = (
    gene_depth
    .sel(gene_id=gene_list, sample=sample_list)
    .to_series()
    .unstack('gene_id')
)

sns.clustermap(d, norm=mpl.colors.PowerNorm(gamma=0.25))


#### O-antigen genes, looser cluster (cosine distance < 0.15 to guide gene UHGG004518_00665)

In [None]:
# Looser variant of the O-antigen cluster. Start from cell-envelope genes (COG 'M') with
# a high-confidence hit in more than 5% of the high-power strains, then keep those within
# cosine distance 0.15 of the guide gene.
strain_list = high_power_strain_list
# Same category as the cell-envelope section; set explicitly so this cell does not
# depend on which cell ran last.
this_cog_category = 'M'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
# Sorted for a reproducible gene order (set iteration order of strings changes between
# kernels because of hash randomization).
gene_list = sorted(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05))
)
guide_gene = 'UHGG004518_00665'
assert guide_gene in gene_list, f"Guide gene {guide_gene} did not pass the gene filters."

x = strain_depth.loc[gene_list, strain_list]

# Cosine distance from the guide gene to every selected gene (a single row, instead of
# the full pairwise matrix).
_guide_gene_dist = pd.Series(
    sp.spatial.distance.cdist(x.loc[[guide_gene]], x, metric='cosine')[0],
    index=x.index,
)
_gene_cluster_list = idxwhere(_guide_gene_dist.sort_values() < 0.15)
print(len(_gene_cluster_list))

x_gene_cluster_list = _gene_cluster_list  # FIXME

gene_meta.join(cog_meta, on='cog', rsuffix='_cog_name').loc[_gene_cluster_list][['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]

In [None]:
gene_list = x_gene_cluster_list
strain_list = high_power_strain_list

# Depth of each cluster gene in each high-power strain.
x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) >= 2e4:
    print("Too many genes for clustermap:", len(gene_list))
else:
    # Strain columns follow the fitted genotype linkage; gene rows use cosine distance.
    cg = sns.clustermap(
        x,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        xticklabels=1,
        yticklabels=1,
    )

In [None]:
strain_list = [36, 67, 81]
# Concatenate the sample lists of the chosen strains, in strain order.
sample_list = [
    sample
    for strain_id in strain_list
    for sample in strain_to_sample_list[strain_id]
]
sf.plot.plot_metagenotype(
    fit.sel(sample=sample_list, position=position_ss),
    row_linkage_func=lambda w: w.genotype.linkage('position'),
)

In [None]:
# For each column (genome or strain) of each presence table, count how many cluster
# genes it carries, and compare the distributions.
gene_list = x_gene_cluster_list

# fill_value keeps the frames boolean; reindex + fillna upcasts to object dtype and can
# trigger pandas' fillna-downcasting FutureWarning.
x = (reference_gene > 0).reindex(gene_list, fill_value=False)
y = (isolate_gene > 0).reindex(gene_list, fill_value=False)
z = high_confidence_hit.reindex(gene_list, fill_value=False)
a = high_confidence_hit[high_power_strain_list].reindex(gene_list, fill_value=False)

# Integer-aligned bins; the counts range from 0 to len(gene_list).
bins = np.arange(len(gene_list) + 3)
# Semi-transparent so later histograms do not hide earlier ones.
plt.hist(x.sum(), bins=bins, alpha=0.5, label='reference_gene')
plt.hist(y.sum(), bins=bins, alpha=0.5, label='isolate_gene')
plt.hist(z.sum(), bins=bins, alpha=0.5, label='high_confidence_hit (all strains)')
plt.hist(a.sum(), bins=bins, alpha=0.5, label='high_confidence_hit (high-power strains)')
plt.xlabel('cluster genes present')
plt.ylabel('count')
plt.legend()

plt.yscale('log')

In [None]:
gene_list = x_gene_cluster_list
# Per-sample depth of each cluster gene, restricted to the current `sample_list`.
d = (
    gene_depth
    .sel(gene_id=gene_list, sample=sample_list)
    .to_series()
    .unstack('gene_id')
)

sns.clustermap(d, norm=mpl.colors.PowerNorm(gamma=0.25))


In [None]:
# Total depth of each cluster gene per subject, annotated with per-subject species depth
# and IBD diagnosis.
gene_list = x_gene_cluster_list
d = (
    gene_depth.sel(gene_id=gene_list)
    .to_series()
    .unstack('gene_id')
    .groupby(sample_meta.subject_id)
    .sum()
)
depth = species_depth.groupby(sample_meta.subject_id).sum()
# Fourth-root scaling compresses the dynamic range before mapping to colors.
depth_colors = ((depth / depth.max()) ** 0.25).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[d.index, 'ibd_diagnosis'].map(
    {'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'}
)

sns.clustermap(
    d,
    norm=mpl.colors.PowerNorm(gamma=0.25),
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
)


In [None]:
# Community plot for the current `sample_list`; strains below the 0.05 cutoff are
# presumably dropped by `drop_low_abundance_strains`.
sf.plot.plot_community(
    fit.sel(sample=sample_list).drop_low_abundance_strains(0.05),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)

In [None]:
gene_list = x_gene_cluster_list
# Per-sample depth of each cluster gene, restricted to the current `sample_list`.
d = (
    gene_depth
    .sel(gene_id=gene_list, sample=sample_list)
    .to_series()
    .unstack('gene_id')
)

sns.clustermap(d, norm=mpl.colors.PowerNorm(gamma=0.25))


### What about defense genes (COG V)?

In [None]:
# Variably present defense genes (COG category 'V'). Each has a high-confidence hit in
# more than 5% of the high-power strains and a `high_confidence_not_hit` (presumably a
# confident absence) in more than 20% of them.
strain_list = high_power_strain_list
this_cog_category = 'V'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
# Sorted so the gene order (and the clustermap built from it) is reproducible: the
# iteration order of a set of strings changes between kernels (hash randomization).
gene_list = sorted(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(
        (high_confidence_hit[strain_list].mean(1) > 0.05)
        & (high_confidence_not_hit[strain_list].mean(1) > 0.2)
    ))
)

x = strain_depth.loc[gene_list, strain_list]
# Alternative input: x = high_confidence_hit.loc[gene_list, strain_list]

if len(gene_list) < 2e4:
    cg = sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=2,  # label every other gene
        xticklabels=1,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

#### Defense-gene cluster around guide gene UHGG153923_03568?

In [None]:
# Defense genes (COG 'V') with a high-confidence hit in more than 5% of the high-power
# strains. Keep those whose strain-depth profile is within cosine distance 0.1 of the
# guide gene.
strain_list = high_power_strain_list
# Same category as the defense section above; set explicitly so this cell does not
# depend on which cell ran last.
this_cog_category = 'V'
cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
# Sorted for a reproducible gene order (set iteration order of strings changes between
# kernels because of hash randomization).
gene_list = sorted(
    set(idxwhere(gene_meta.cog.isin(cog_list)))
    & set(idxwhere(high_confidence_hit[strain_list].mean(1) > 0.05))
)
guide_gene = 'UHGG153923_03568'
assert guide_gene in gene_list, f"Guide gene {guide_gene} did not pass the gene filters."

x = strain_depth.loc[gene_list, strain_list]

# Cosine distance from the guide gene to every selected gene (a single row, instead of
# the full pairwise matrix).
_guide_gene_dist = pd.Series(
    sp.spatial.distance.cdist(x.loc[[guide_gene]], x, metric='cosine')[0],
    index=x.index,
)
_gene_cluster_list = idxwhere(_guide_gene_dist.sort_values() < 0.1)
print(len(_gene_cluster_list))

x_gene_cluster_list = _gene_cluster_list  # FIXME

gene_meta.join(cog_meta, on='cog', rsuffix='_cog_name').loc[_gene_cluster_list][['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]

In [None]:
gene_list = x_gene_cluster_list
strain_list = high_power_strain_list

# Depth of each cluster gene in each high-power strain.
x = strain_depth.loc[gene_list, strain_list]

if len(gene_list) >= 2e4:
    print("Too many genes for clustermap:", len(gene_list))
else:
    # Strain columns follow the fitted genotype linkage; gene rows use cosine distance.
    cg = sns.clustermap(
        x,
        col_linkage=fit.genotype.sel(strain=strain_list).linkage("strain"),
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        xticklabels=1,
        yticklabels=1,
    )

In [None]:
strain_list = [16, 6, 27]
# Concatenate the sample lists of the chosen strains, in strain order.
sample_list = [
    sample
    for strain_id in strain_list
    for sample in strain_to_sample_list[strain_id]
]
# Positions are ordered by genotype linkage, samples by metagenotype linkage.
sf.plot.plot_metagenotype(
    fit.sel(sample=sample_list, position=position_ss),
    row_linkage_func=lambda w: w.genotype.linkage('position'),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
)

In [None]:
# Community plot for the same samples, with strains below the 0.05 cutoff presumably
# removed by `drop_low_abundance_strains`.
sf.plot.plot_community(
    (
        fit
        .sel(sample=sample_list, position=position_ss)
        .drop_low_abundance_strains(0.05)
    ),
    row_linkage_func=lambda w: w.genotype.linkage('strain'),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
)

In [None]:
# For each column (genome or strain) of each presence table, count how many cluster
# genes it carries, and compare the distributions.
gene_list = x_gene_cluster_list

# fill_value keeps the frames boolean; reindex + fillna upcasts to object dtype and can
# trigger pandas' fillna-downcasting FutureWarning.
x = (reference_gene > 0).reindex(gene_list, fill_value=False)
y = (isolate_gene > 0).reindex(gene_list, fill_value=False)
z = high_confidence_hit.reindex(gene_list, fill_value=False)
a = high_confidence_hit[high_power_strain_list].reindex(gene_list, fill_value=False)

# Integer-aligned bins; the counts range from 0 to len(gene_list).
bins = np.arange(len(gene_list) + 3)
plt.hist(x.sum(), bins=bins, alpha=0.5, label='reference_gene')
plt.hist(y.sum(), bins=bins, alpha=0.5, label='isolate_gene')
plt.hist(z.sum(), bins=bins, alpha=0.5, label='high_confidence_hit (all strains)')
plt.hist(a.sum(), bins=bins, alpha=0.5, label='high_confidence_hit (high-power strains)')
plt.xlabel('cluster genes present')
plt.ylabel('count')
plt.legend()

plt.yscale('log')

In [None]:
gene_list = x_gene_cluster_list
# Per-sample depth of each cluster gene, restricted to the current `sample_list`.
d = (
    gene_depth
    .sel(gene_id=gene_list, sample=sample_list)
    .to_series()
    .unstack('gene_id')
)

sns.clustermap(d, norm=mpl.colors.PowerNorm(gamma=0.25))


In [None]:
# Total depth of each cluster gene per subject, clustered by cosine distance and
# annotated with per-subject species depth and IBD diagnosis.
gene_list = x_gene_cluster_list
d = (
    gene_depth.sel(gene_id=gene_list)
    .to_series()
    .unstack('gene_id')
    .groupby(sample_meta.subject_id)
    .sum()
)
depth = species_depth.groupby(sample_meta.subject_id).sum()
# Fourth-root scaling compresses the dynamic range before mapping to colors.
depth_colors = ((depth / depth.max()) ** 0.25).map(mpl.cm.viridis)
diagnosis_colors = subject.loc[d.index, 'ibd_diagnosis'].map(
    {'CD': 'red', 'UC': 'pink', 'nonIBD': 'grey'}
)

sns.clustermap(
    d + 1e-5,  # tiny pseudocount; presumably avoids all-zero rows under the cosine metric
    metric='cosine',
    norm=mpl.colors.PowerNorm(gamma=0.25),
    row_colors=depth_colors.to_frame('depth').assign(ibd=diagnosis_colors),
)


# Do strain-private genes show up in samples with that strain at less than 100%?

In [None]:
# Display the `strain_meta` row for strain 8, the strain examined in the next cell.
strain_meta.loc[8]

In [None]:
strain = 8

# Keep genes with a high-confidence hit in `strain` that also have a moderate hit in
# exactly one strain overall. That one strain is presumably `strain` itself; TODO
# confirm that high-confidence hits are a subset of moderate hits.
strain_private_genes = idxwhere(
    high_confidence_hit[strain] & (moderate_hit.sum(axis=1) == 1)
)
(
    gene_meta
    .join(cog_meta, on='cog', rsuffix='_cog_name')
    .loc[strain_private_genes]
    [['gene', 'ec_number', 'cog', 'product', 'description', 'gene_cog_name', 'pathway']]
)

In [None]:
from lib.pandas_util import align_indexes

def _log_minmax_viridis(values):
    """Color-map a Series with viridis after log(values + 1e-2) and min-max scaling."""
    logged = np.log(values + 1e-2)
    return ((logged - logged.min()) / (logged.max() - logged.min())).map(mpl.cm.viridis)


# Per-sample series for strain `strain`:
#   x: species depth
#   y: the strain's community fraction (0 for samples absent from the fit)
#   r: summed community fraction of the high-power strains
#   z: species depth scaled by the strain's fraction
#   d: depth of each strain-private gene
x = species_depth
y = fit.community.data.sel(strain=strain).to_series().reindex(x.index, fill_value=0)
r = fit.community.data.sel(strain=high_power_strain_list).sum("strain").to_series()
z = x * y
d = gene_depth.sel(gene_id=strain_private_genes).to_series().unstack('gene_id')
# `b` keeps only samples where the private genes are detected (total depth > 1e-3) or
# the strain's fraction exceeds 0.1. Only True entries are retained, so it restricts
# the alignment below.
b = ((d.sum(1) > 1e-3) | (y > 1e-1)).pipe(lambda keep: keep[keep])
# NOTE(review): align_indexes presumably restricts all inputs to a shared index.
x, y, r, z, d, b = align_indexes(x, y, r, z, d, b)

sns.clustermap(
    d + 1e-2,  # pseudocount so zero depths can be drawn with LogNorm
    metric='cosine',
    norm=mpl.colors.LogNorm(),
    row_colors=pd.DataFrame(dict(
        y=y.map(mpl.cm.viridis),
        r=r.map(mpl.cm.viridis),
        x=_log_minmax_viridis(x),
        z=_log_minmax_viridis(z),
    )),
)

# Can we quantify the degree of genotype co-similarity between genes and SNPs?