# Preamble

In [None]:
%load_ext autoreload

In [None]:
import os
# Switch the working directory to the parent of the current one; presumably this
# makes the relative `data/`, `data_temp/` and `ref/` paths used below resolve.
# NOTE(review): not idempotent — re-running this cell climbs one more level.
os.chdir('..')
# Display the absolute path of the new working directory.
os.path.realpath(os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, align_indexes
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time

In [None]:
import sfacts as sf

In [None]:
# Long-format (sample, species_id, depth) table pivoted to a samples x species
# matrix; combinations absent from the file are filled with 0.
all_species_gtpro_depth = (
    pd.read_table(
        f'data/hmp2.a.r.proc.gtpro.species_depth.tsv',
        dtype=dict(sample=str, species_id=str, depth=float),
        index_col=['sample', 'species_id'],
    )
    .squeeze()
    .unstack('species_id', fill_value=0)
)
all_species_gtpro_depth

In [None]:
# Taxonomy strings keyed by species ID, split on ';' into a temporary list column.
species_taxonomy = (
    pd.read_table(
        'ref/gtpro/species_taxonomy_ext.tsv',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index('species_id')
    [['taxonomy_string']]
    .assign(taxonomy_split=lambda x: x.taxonomy_string.str.split(';'))
)

# Elements 1..6 of the split string become columns named after the rank prefix
# (element 0 is not extracted).
for level_number, level_name in enumerate(['p__', 'c__', 'o__', 'f__', 'g__', 's__'], start=1):
    species_taxonomy = species_taxonomy.assign(
        **{level_name: species_taxonomy.taxonomy_split.apply(lambda x: x[level_number])}
    )

# Drop the helper column and keep only the species present in the depth matrix, in its column order.
species_taxonomy = (
    species_taxonomy
    .drop(columns=['taxonomy_split'])
    .loc[all_species_gtpro_depth.columns]
)

In [None]:
# Number of samples in which each species has depth above 0.5, largest first.
(
    all_species_gtpro_depth
    .gt(0.5)
    .sum()
    .sort_values(ascending=False)
)

In [None]:
# Species analysed in the rest of the notebook; interpolated into the file paths below.
species_id = '102506'

In [None]:
# Display the taxonomy row of the selected species.
species_taxonomy.loc[species_id]

In [None]:
# Load the fitted sfacts world for this species and drop low-abundance strains (argument 0.05).
fit = (
    sf.World
    .load(f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.world.nc')
    .drop_low_abundance_strains(0.05)
)
fit.sizes

In [None]:
# Community plot of the fit; columns are ordered by the linkage of the world's metagenotype.
sf.plot.plot_community(
    fit,
    col_linkage_func=lambda w: w.metagenotype.linkage(),
)

In [None]:
# Long-format (sample, strain) table pivoted to samples x strains; missing pairs become 0.
strain_frac = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.comm.tsv',
        index_col=['sample', 'strain'],
    )
    .squeeze()
    .unstack(fill_value=0)
)
strain_frac.shape

In [None]:
# Per-sample species depth estimated from the MIDAS gene (75% centroid) profile.
species_gene_depth = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_gene75.species_depth.tsv',
        names=['sample', 'depth'],
        index_col='sample',
    )
    .squeeze()
)
# Per-sample GT-Pro depth of the selected species. The samples x species matrix
# was already built from the same file with the same arguments
# (`all_species_gtpro_depth`), so take its column instead of re-reading and
# re-pivoting the whole table.
gtpro_depth = all_species_gtpro_depth[species_id]

In [None]:
# Gene-based vs. GT-Pro species depth per sample, on symlog axes, with a y = x reference line.
d = pd.DataFrame({'gene': species_gene_depth, 'gtpro': gtpro_depth})

plt.scatter('gtpro', 'gene', data=d, s=3, alpha=0.3)
plt.plot([0, 1e2], [0, 1e2])
plt.xscale('symlog', linthresh=1e-4)
plt.yscale('symlog', linthresh=1e-4)

In [None]:
# Same comparison as the symlog scatter, but with both depths cube-root transformed.
d = pd.DataFrame({'gene': species_gene_depth, 'gtpro': gtpro_depth})
_trnsf = np.cbrt

plt.scatter('gtpro', 'gene', data=_trnsf(d), s=3, alpha=0.3)
# y = x reference line, drawn in transformed coordinates.
plt.plot(
    [_trnsf(0), _trnsf(1e2)],
    [_trnsf(0), _trnsf(1e2)],
)

In [None]:
# Two header-less, two-column tables read as Series indexed by their first column.
# NOTE(review): the index is labelled 'sample' here, but later cells use these
# Series as if they were indexed by gene — confirm the column naming.
species_corr = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_gene75.species_correlation.tsv',
        names=['sample', 'correlation'],
        index_col='sample',
    )
    .squeeze()
)
species_depth_ratio = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_gene75.species_depth_ratio.tsv',
        names=['sample', 'depth_ratio'],
        index_col='sample',
    )
    .squeeze()
)

In [None]:
# log10 depth ratio of the entries whose species correlation exceeds 0.95.
plt.hist(
    np.log10(species_depth_ratio[species_corr.gt(0.95)]),
    bins=100,
)
None

In [None]:
# Standard deviation of the log10 depth ratio among entries with species correlation > 0.95.
np.log10(species_depth_ratio[species_corr > 0.95]).std()

In [None]:
# Gene x strain correlation matrix (missing pairs -> 0) at the 75% centroid level.
strain_corr_75 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene75.strain_correlation.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack(fill_value=0)
)
# Gene x strain depth-ratio matrix (missing pairs stay NaN).
strain_depth_75 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene75.strain_depth_ratio.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack()
)
# Align the two matrices first with the default axis, then along the columns.
strain_corr_75, strain_depth_75 = align_indexes(
    *align_indexes(strain_corr_75, strain_depth_75),
    axis="columns",
)

In [None]:
# Pangenome cluster table indexed by its `centroid_99` column (kept as a column too);
# the index is renamed to `gene_id`.
gene_meta = (
    pd.read_table(f'ref_temp/midasdb_uhgg/pangenomes/{species_id}/cluster_info.txt')
    .set_index('centroid_99', drop=False)
    .rename_axis(index='gene_id')
)

In [None]:
# Per-gene annotations for the selected species. The species ID is interpolated
# (it used to be hard-coded as 102506) so this cell follows `species_id` like
# every other path in the notebook.
gene_annotations = pd.read_table(
    f'ref_temp/midasdb_uhgg.{species_id}.centroid_75.tsv',
    index_col='locus_tag',
)
# COG metadata: a header-less table, so column names are supplied explicitly.
cog_meta = pd.read_table(
    'ref/cog-20.meta.tsv',
    names=['cog', 'categories', 'description', 'gene', 'pathway', '_1', '_2'],
    index_col=['cog']
)
cog_meta

In [None]:
# COG category descriptions (header-less table), indexed by the category column.
cog_category = pd.read_table(
    'ref/cog-20.categories.tsv',
    names=['category', 'description'],
    index_col='category',
)

In [None]:
# Gene depth array for the 75% centroid level; later cells select on its
# `gene_id` and `sample` dimensions.
gene_depth_75 = xr.load_dataarray(f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_gene75.depth.nc')

In [None]:
# Index labels whose species correlation exceeds 0.95.
species_genes = idxwhere(species_corr > 0.95)

In [None]:
# Per-strain correlation threshold, read from a header-less two-column table.
strain_threshold = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene75.strain_correlation_threshold.tsv',
        names=['strain', 'threshold'],
        index_col='strain',
    )
    .squeeze()
)
# Show the strains with the highest thresholds first.
strain_threshold.sort_values(ascending=False)

In [None]:
# Strain inspected by the single-strain cells below (used as a column label).
strain = 15

In [None]:
# Depth ratios (all strains) of the genes whose correlation with `strain`
# exceeds that strain's threshold.
d = strain_depth_75.loc[strain_corr_75[strain].gt(strain_threshold[strain])]

sns.clustermap(
    d,
    metric='cosine',
    norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
)

In [None]:
# Correlations (all strains) of the genes whose correlation with `strain`
# exceeds that strain's threshold.
d = strain_corr_75.loc[strain_corr_75[strain].gt(strain_threshold[strain])]

sns.clustermap(d, metric='cosine')

In [None]:
# Depth of the strain's above-threshold genes, restricted to samples where the
# strain makes up more than 95% of the community.
d = (
    gene_depth_75
    .sel(
        gene_id=idxwhere(strain_corr_75[strain] > strain_threshold[strain]),
        sample=idxwhere(strain_frac[strain] > 0.95),
    )
    .to_series()
    .unstack()
)
print(d.shape)

sns.clustermap(
    d,
    metric='cosine',
    norm=mpl.colors.SymLogNorm(linthresh=0.1),
)

In [None]:
# Samples where the strain makes up more than 95% of the community.
sample_list = idxwhere(strain_frac[strain] > 0.95)
# Allele totals summed over those samples.
mgen_allele_totals = fit.metagenotype.sel(sample=sample_list).sum("sample")

# Positions ordered by pooled alt-allele fraction, keeping only those strictly
# between 0 and 1 (undefined fractions are dropped first).
position_order = (
    (mgen_allele_totals.sel(allele="alt") / mgen_allele_totals.sum("allele"))
    .to_series()
    .dropna()
    .loc[lambda x: (x > 0) & (x < 1)]
    .sort_values()
    .index
)

sf.plot.plot_metagenotype(
    fit.sel(sample=sample_list, position=position_order),
)

In [None]:
# Compare, per aggregated gene, the strain dissimilarity (1 - best correlation
# within the centroid_75 group) against the species dissimilarity, and derive a
# correlation cutoff `q` from the genes that pass the species threshold.
species_corr_thresh = 0.95
strain_corr_quantile = 0.95


d0 = pd.DataFrame(dict(
    # 1 - max correlation with `strain` among genes sharing a centroid_75 label.
    strain_max_diss=1 - strain_corr_75[strain].groupby(gene_meta.centroid_75).max(),
    species_diss=1 - species_corr,
    # Depth ratios summed over genes sharing a centroid_75 label.
    depth_ratio=strain_depth_75[strain].groupby(gene_meta.centroid_75).sum(),
))

fig, ax = plt.subplots(figsize=(12, 10))

# Scatter coloured by depth ratio; both axes are logarithmic.
ax.scatter(
    'strain_max_diss', 'species_diss', c='depth_ratio', data=d0,
    s=2, alpha=0.7, norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=1e-1, vmax=10),
)
ax.set_xscale('log')
ax.set_yscale('log')
# Horizontal line at the species dissimilarity cutoff.
ax.axhline(1 - species_corr_thresh, lw=1, linestyle=':', c='k')

# Second y-axis: density of strain dissimilarity among the genes below the species cutoff.
ax1 = ax.twinx()
d1 = d0[d0['species_diss'] < (1 - species_corr_thresh)]
sns.kdeplot(d1.strain_max_diss, ax=ax1)
# Stretch the density axis 5x so the curve stays in the lower part of the panel.
ax1.set_ylim(ymax=ax1.get_ylim()[1] * 5)
# `q` is a correlation: 1 minus the `strain_corr_quantile` quantile of the dissimilarity.
q = 1 - d1.strain_max_diss.quantile(strain_corr_quantile)
# Vertical line at the corresponding dissimilarity (1 - q).
ax1.axvline(1 - q, lw=1, linestyle='--', c='k')

# Prints: the cutoff, the number of aggregated genes below the dissimilarity
# cutoff, their summed depth ratio, and the summed depth ratio of the
# un-aggregated genes whose correlation exceeds `q`.
print(
    q,
    (d0.strain_max_diss < 1 - q).sum(),
    d0[d0.strain_max_diss < 1 - q].depth_ratio.sum(),
    (strain_depth_75[strain] * (strain_corr_75[strain] > q)).sum()
)

In [None]:
# Genes whose correlation exceeds the per-strain threshold for at least one strain.
# `axis` is passed by keyword: pandas >= 2.0 no longer accepts it positionally
# for `DataFrame.any`.
gene_list = idxwhere((strain_corr_75 > strain_threshold).any(axis=1))
d = strain_depth_75.loc[gene_list]

# Depth ratios of those genes across all strains.
sns.clustermap(
    d,
    metric='cosine',
    norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.2, vmax=5)
)

In [None]:
# Genes whose correlation exceeds the per-strain threshold for at least one strain.
# `axis` is passed by keyword: pandas >= 2.0 no longer accepts it positionally
# for `DataFrame.any`.
gene_list = idxwhere((strain_corr_75 > strain_threshold).any(axis=1))
d = strain_corr_75.loc[gene_list]

# Correlations of those genes across all strains.
sns.clustermap(
    d,
    metric='cosine',
)

In [None]:
# Classify every (gene, strain) pair by whether it passes the depth-ratio window
# and/or the per-strain correlation threshold, then histogram the per-gene
# fraction of strains falling in each class.
depth_ratio_bound = 5

# `axis` is passed by keyword: pandas >= 2.0 no longer accepts it positionally
# for `DataFrame.any`.
gene_list = idxwhere((strain_corr_75 > strain_threshold).any(axis=1))
# Depth ratio strictly inside (1 / bound, bound).
depth_hit = (strain_depth_75.loc[gene_list] < depth_ratio_bound) & (strain_depth_75.loc[gene_list] > 1 / depth_ratio_bound)
# NOTE(review): unlike the other masks this one is built over ALL genes, not
# only `gene_list`, so `grey_high_depth` below is index-aligned against a larger
# frame — confirm this is intended.
high_depth = strain_depth_75 > depth_ratio_bound
corr_hit = strain_corr_75.loc[gene_list] > strain_threshold
white = depth_hit & corr_hit            # passes both criteria
grey = depth_hit ^ corr_hit             # passes exactly one criterion
black = ~depth_hit & ~corr_hit          # passes neither criterion
grey_high_depth = corr_hit & high_depth  # correlated, but depth ratio above the bound

bins = np.linspace(0, 1, num=21)

plt.hist(white.mean(axis=1), bins=bins, label='white all')
plt.hist(white.mean(axis=1)[species_corr > 0.95], bins=bins, label='white species genes')
plt.hist(grey.mean(axis=1)[species_corr > 0.95], bins=bins, label='grey species genes')
plt.hist(grey_high_depth.mean(axis=1)[species_corr > 0.95], bins=bins, label='grey high depth species genes')
plt.hist(black.mean(axis=1)[species_corr > 0.95], bins=bins, label='black species genes')

plt.legend()
None

In [None]:
# Genes that are "white" in more than 90% of strains.
genes_list = idxwhere(white.mean(1) > 0.9)
d = strain_depth_75.loc[genes_list]

sns.clustermap(
    d,
    metric='cosine',
    norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.2, vmax=5),
)

# COG annotation of the selected genes joined to the COG metadata.
genes_cog = gene_annotations.loc[genes_list].COG.to_frame().join(cog_meta, on='COG')

# Ten most frequent pathways.
print(
    genes_cog.pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)

# Ten most frequent category strings, with the category description joined on.
print(
    genes_cog.categories
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
    .to_frame()
    .join(cog_category)
)

In [None]:
# Genes that are "white" in more than 5% of strains and "black" in more than 80%.
genes_list = idxwhere((white.mean(1) > 0.05) & (black.mean(1) > 0.8))
d = strain_depth_75.loc[genes_list]

sns.clustermap(
    d,
    metric='cosine',
    norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.2, vmax=5),
)

# COG annotation of the selected genes joined to the COG metadata.
genes_cog = gene_annotations.loc[genes_list].COG.to_frame().join(cog_meta, on='COG')

# Ten most frequent pathways.
print(
    genes_cog.pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)

# Ten most frequent category strings, with the category description joined on.
print(
    genes_cog.categories
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
    .to_frame()
    .join(cog_category)
)

In [None]:
# Genes that are "white" in more than 20% of strains and "black" in more than 20%.
genes_list = idxwhere((white.mean(1) > 0.2) & (black.mean(1) > 0.2))
d = strain_depth_75.loc[genes_list]

sns.clustermap(
    d,
    metric='cosine',
    norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.2, vmax=5),
)

# COG annotation of the selected genes joined to the COG metadata.
genes_cog = gene_annotations.loc[genes_list].COG.to_frame().join(cog_meta, on='COG')

# Ten most frequent pathways.
print(
    genes_cog.pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)

# Ten most frequent category strings, with the category description joined on.
print(
    genes_cog.categories
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
    .to_frame()
    .join(cog_category)
)

In [None]:
# Number of strains in which a gene is "white" (x) against its species correlation (y),
# after aligning the two on their shared index.
x, y = align_indexes(white.sum(1), species_corr)

plt.scatter(x, y, alpha=0.1, s=1)

In [None]:
# Depth table of the genes with species correlation > 0.95.
d0 = (
    gene_depth_75
    .sel(gene_id=idxwhere(species_corr > 0.95))
    .to_series()
    .unstack()
)

# Column-wise coefficient of variation (std / mean) against GT-Pro depth, log-log.
x = gtpro_depth
y = (d0.std() / d0.mean()).sort_values(ascending=False)
d1 = pd.DataFrame({'x': x, 'y': y})

plt.scatter('x', 'y', data=d1, s=1)
plt.xscale('log')
plt.yscale('log')

# Per strain: number of genes with depth ratio strictly between 0.2 and 5.
number_genes_in_depth_range = (strain_depth_75.gt(0.2) & strain_depth_75.lt(5)).sum()
number_genes_in_depth_range

In [None]:
# Re-reads the pangenome cluster table, indexed by `centroid_99` (kept as a
# column) with the index renamed to `gene_id`.
gene_meta = (
    pd.read_table(f'ref_temp/midasdb_uhgg/pangenomes/{species_id}/cluster_info.txt')
    .set_index('centroid_99', drop=False)
    .rename_axis(index='gene_id')
)

In [None]:
# Per strain: among genes whose depth ratio lies strictly inside (1/thresh, thresh),
# the fraction that also exceed the strain's selection threshold.
# NOTE(review): `strain_meta` is not defined in any visible cell; this cell
# raises NameError on a fresh kernel. Its columns match the `out` table built in
# the "Snakefile" cell further down — confirm where it should be loaded from.
thresh = 5

genes_in_depth_range = (strain_depth_75 < thresh) & (strain_depth_75 > 1 / thresh)
ratio_genes_in_depth_range = (
    (genes_in_depth_range & (strain_corr_75 > strain_meta.strain_selection_threshold)).sum()
    / genes_in_depth_range.sum()
)

ratio_genes_in_depth_range

In [None]:
# NOTE(review): `strain_meta` is not defined in any visible cell — confirm its source.
strain_meta.number_strain_agg_hit

In [None]:
# Per strain: genes inside the depth-ratio window that also exceed the selection
# threshold, expressed (a) relative to all genes inside the window and
# (b) relative to `strain_meta.number_strain_agg_hit`.
# NOTE(review): `strain_meta` is not defined in any visible cell — confirm its source.
thresh = 5

genes_in_depth_range = (strain_depth_75 < thresh) & (strain_depth_75 > 1 / thresh)
genes_in_depth_range_and_hit = (genes_in_depth_range & (strain_corr_75 > strain_meta.strain_selection_threshold))

ratio_genes_in_depth_range = genes_in_depth_range_and_hit.sum() / genes_in_depth_range.sum()
frac_genes_in_depth_range = genes_in_depth_range_and_hit.sum() / strain_meta.number_strain_agg_hit

frac_genes_in_depth_range

In [None]:
# Gene depth arrays at the 99%, 95% and 75% centroid levels.
gene_depth_99 = xr.load_dataarray(
    f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_gene99.depth.nc'
)
gene_depth_95 = xr.load_dataarray(
    f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_gene95.depth.nc'
)
gene_depth_75 = xr.load_dataarray(
    f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_gene75.depth.nc'
)

In [None]:
# Gene x strain correlation matrix (missing pairs -> 0) at the 99% centroid level.
strain_corr_99 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene99.strain_correlation.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack(fill_value=0)
)
# Gene x strain depth-ratio matrix (missing pairs stay NaN).
strain_depth_99 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene99.strain_depth_ratio.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack()
)
# Align the two matrices first with the default axis, then along the columns.
strain_corr_99, strain_depth_99 = align_indexes(
    *align_indexes(strain_corr_99, strain_depth_99),
    axis="columns",
)

In [None]:
# Gene x strain correlation matrix (missing pairs -> 0) at the 95% centroid level.
strain_corr_95 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene95.strain_correlation.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack(fill_value=0)
)
# Gene x strain depth-ratio matrix (missing pairs stay NaN).
strain_depth_95 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene95.strain_depth_ratio.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack()
)
# Align the two matrices first with the default axis, then along the columns.
strain_corr_95, strain_depth_95 = align_indexes(
    *align_indexes(strain_corr_95, strain_depth_95),
    axis="columns",
)

In [None]:
# Gene x strain correlation matrix (missing pairs -> 0) at the 90% centroid level.
strain_corr_90 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene90.strain_correlation.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack(fill_value=0)
)
# Gene x strain depth-ratio matrix (missing pairs stay NaN).
strain_depth_90 = (
    pd.read_table(
        f'data_temp/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0.midas_gene90.strain_depth_ratio.tsv',
        index_col=['gene_id', 'strain'],
    )
    .squeeze()
    .unstack()
)
# Align the two matrices first with the default axis, then along the columns.
strain_corr_90, strain_depth_90 = align_indexes(
    *align_indexes(strain_corr_90, strain_depth_90),
    axis="columns",
)

In [None]:
# Top 20 strains by the number of samples in which their fraction exceeds 0.95.
(strain_frac > 0.95).sum().sort_values(ascending=False).head(20)

In [None]:
# Prototype of a standalone script that computes, for every strain, a
# correlation threshold and summary statistics about the genes selected with it.
# The variables in the first two sections stand in for the values a workflow
# rule would supply.

# Snakefile
species='102506'
stemA='hmp2.a.r.proc'
stemB='filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts26-s75-seed0'
centroid='90'

# Imports are repeated here so the cell body can be lifted into a script as-is.
import pandas as pd
from scipy.stats import trim_mean
from lib.pandas_util import align_indexes, idxwhere

# Script Args
outpath = f"data_temp/sp-{species}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_correlation_threshold.tsv"
species_corr_path=f"data_temp/sp-{species}.{stemA}.midas_gene75.species_correlation.tsv"
strain_corr_path=f"data_temp/sp-{species}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_correlation.tsv"
gene_meta_path=f"ref_temp/midasdb_uhgg/pangenomes/{species}/cluster_info.txt"  # TODO: Add to recipe
strain_depth_path=f"data_temp/sp-{species}.{stemA}.gtpro.{stemB}.midas_gene{centroid}.strain_depth_ratio.tsv"  # Only used for depth profiling
aggregate_genes_by="centroid_75"
species_corr_threshold=float(0.95)
strain_corr_quantile=float(0.95)
trim_frac=float(0.2)  # Only used for calculating mean species-gene depth


# Load Data
species_corr_agg = pd.read_table(species_corr_path, names=['gene_id', 'correlation'], index_col='gene_id').squeeze()
strain_corr = pd.read_table(strain_corr_path, index_col=['gene_id', 'strain']).squeeze().unstack(fill_value=0)
gene_meta = pd.read_table(gene_meta_path).set_index('centroid_99', drop=False).rename_axis(index='gene_id')
strain_depth = pd.read_table(strain_depth_path, index_col=['gene_id', 'strain']).squeeze().unstack()

# Align Data
strain_corr, strain_depth = align_indexes(*align_indexes(strain_corr, strain_depth), axis='columns')

# Aggregate genes to the `aggregate_genes_by` level: best correlation and summed depth ratio per group.
strain_corr_agg = strain_corr.groupby(gene_meta[aggregate_genes_by]).max()
strain_depth_agg = strain_depth.groupby(gene_meta[aggregate_genes_by]).sum()
species_corr_agg, strain_corr_agg, strain_depth_agg = align_indexes(species_corr_agg, strain_corr_agg, strain_depth_agg)


# Calculate the strain correlation threshold for each strain at which strain_corr_quantile
# of the species genes (defined as those passing the species_corr_threshold)
# are also assigned to the strain.
species_agg_hit = species_corr_agg > species_corr_threshold
number_species_agg_hit = species_agg_hit.sum()
strain_selection_threshold = strain_corr_agg.reindex(idxwhere(species_agg_hit)).quantile(1 - strain_corr_quantile)

# Stats on aggregated gene hits (aggs where best hit is over threshold)
strain_agg_hit = strain_corr_agg.gt(strain_selection_threshold)
number_strain_agg_hit = strain_agg_hit.sum()
total_depth_ratio_strain_agg_hit = (strain_depth_agg * strain_agg_hit).sum()

# Stats on aggregated gene hits (aggs where best hit is over threshold) that were also species hits
# Fixed: this used the undefined name `agg_hit`; the per-strain hit table is `strain_agg_hit`.
# The double transpose lets the gene-indexed `species_agg_hit` align against the gene axis.
strain_species_agg_hit = (strain_agg_hit.T & species_agg_hit).T
number_strain_species_agg_hit = strain_species_agg_hit.sum()
total_depth_ratio_species_agg_hit = (strain_depth_agg.T * species_agg_hit).T.sum()
total_depth_ratio_strain_species_agg_hit = (strain_depth_agg * strain_species_agg_hit).sum()
mean_depth_ratio_strain_species_agg_hit = (strain_depth_agg[strain_species_agg_hit]).apply(lambda x: trim_mean(x.dropna(), proportiontocut=trim_frac))
mean_depth_ratio_species_agg_hit = (strain_depth_agg[species_agg_hit]).apply(lambda x: trim_mean(x.dropna(), proportiontocut=trim_frac))
# TODO: Calculate ratio of total depth ratios between "known" species aggs and the strain-specific aggs.
# frac_total_depth_ratio_species_agg_hit = total_depth_ratio_strain_species_agg_hit / total_depth_ratio_species_agg_hit

# Stats on gene hits (not aggregated)
strain_gene_hit = strain_corr.gt(strain_selection_threshold)

# Per (aggregated gene, strain): summed depth ratio and count of the un-aggregated hits.
agg_depth_ratio_strain_gene_hit = strain_depth[strain_gene_hit].groupby(gene_meta[aggregate_genes_by]).sum().reindex(species_corr_agg.index).fillna(0)
# Fixed: a trailing `.sum()` collapsed this table to one value per strain, which
# made the row mask `[species_agg_hit]` below impossible to apply and turned
# `number_strain_gene_hit` into a single scalar. It now mirrors
# `agg_depth_ratio_strain_gene_hit`.
agg_tally_strain_gene_hit = strain_depth[strain_gene_hit].notna().groupby(gene_meta[aggregate_genes_by]).sum().reindex(species_corr_agg.index).fillna(0)
number_strain_gene_hit = agg_tally_strain_gene_hit.sum()
total_depth_ratio_strain_gene_hit = agg_depth_ratio_strain_gene_hit.sum()


# Stats on gene hits aggregated after selection
number_strain_species_gene_hit = agg_tally_strain_gene_hit[species_agg_hit].sum()
total_depth_ratio_strain_species_gene_hit = agg_depth_ratio_strain_gene_hit[species_agg_hit].sum()
mean_depth_ratio_strain_species_gene_hit = agg_depth_ratio_strain_gene_hit[species_agg_hit].apply(trim_mean, proportiontocut=trim_frac)
# TODO: Calculate ratio of depth ratios between "known" species genes and the newly identified genes.
# frac_total_depth_ratio_species_gene_hit = total_depth_ratio_strain_species_gene_hit / total_depth_ratio_species_agg_hit

# One row per strain; `number_species_agg` is a scalar and is broadcast to every row.
out = pd.DataFrame(dict(
    strain_selection_threshold=strain_selection_threshold,
    number_species_agg=number_species_agg_hit,
    number_strain_agg_hit=number_strain_agg_hit,
    total_depth_ratio_strain_agg_hit=total_depth_ratio_strain_agg_hit,
    number_strain_species_agg_hit=number_strain_species_agg_hit,
    total_depth_ratio_species_agg_hit=total_depth_ratio_species_agg_hit,
    total_depth_ratio_strain_species_agg_hit=total_depth_ratio_strain_species_agg_hit,
    # frac_total_depth_ratio_species_agg_hit=frac_total_depth_ratio_species_agg_hit,
    mean_depth_ratio_strain_species_agg_hit=mean_depth_ratio_strain_species_agg_hit,
    mean_depth_ratio_species_agg_hit=mean_depth_ratio_species_agg_hit,
    number_strain_gene_hit=number_strain_gene_hit,
    total_depth_ratio_strain_gene_hit=total_depth_ratio_strain_gene_hit,
    number_strain_species_gene_hit=number_strain_species_gene_hit,
    total_depth_ratio_strain_species_gene_hit=total_depth_ratio_strain_species_gene_hit,
    mean_depth_ratio_strain_species_gene_hit=mean_depth_ratio_strain_species_gene_hit,
))

# out.to_csv(...

In [None]:
# Per strain: number of un-aggregated gene hits, counted per aggregated gene and then summed.
strain_depth[strain_gene_hit].notna().groupby(gene_meta[aggregate_genes_by]).sum().reindex(species_corr_agg.index).fillna(0).sum()


In [None]:
# Display the gene-hit count computed in the script cell above.
number_strain_gene_hit

In [None]:
# Per strain: total depth ratio of strain gene hits within species genes, relative to
# the total depth ratio of all species genes.
total_depth_ratio_strain_species_gene_hit / total_depth_ratio_species_agg_hit

In [None]:
# Histogram of the species correlation on [0, 1].
# Fixed: the third argument was the undefined name `nupairplot1`, apparently a
# garbled `num=...`.
# NOTE(review): 101 edges (100 bins) matches the other [0, 1] histograms in this
# notebook — confirm the intended bin count.
bins = np.linspace(0, 1, num=101)

plt.hist(species_corr, bins=bins)
None

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

# log2 depth ratio for all entries, then only those with species correlation above `thresh`.
plt.hist(np.log2(species_depth_ratio + pseudo), bins=bins)
plt.hist(
    np.log2(species_depth_ratio[species_corr.gt(thresh)] + pseudo),
    bins=bins,
)
plt.yscale('log')

# Reference lines at ratios 1, 0.2 and 5.
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')
for ratio_bound in (0.2, 5):
    plt.axvline(np.log2(ratio_bound), lw=0.5, linestyle=':', color='k')

None

In [None]:
# log10 species gene depth for all samples, then only for samples in which some
# strain's fraction exceeds 0.95.
plt.hist(np.log10(species_gene_depth + 1e-5), bins=100)
plt.hist(
    np.log10(species_gene_depth.loc[idxwhere(strain_frac.max(1) > 0.95)] + 1e-5),
    bins=100,
)
plt.yscale('log')
None

In [None]:
bins = np.linspace(0, 1, num=101)

# Correlation with `strain` for all genes, then only genes whose depth ratio is
# strictly between 0.5 and 2.
plt.hist(strain_corr_99[strain], bins=bins)
plt.hist(
    strain_corr_99[strain][strain_depth_99[strain].gt(0.5) & strain_depth_99[strain].lt(2)],
    bins=bins,
)
plt.yscale('log')
None

In [None]:
bins = np.linspace(0, 1, num=101)

# Correlation with `strain` for all genes, then only genes whose depth ratio is
# strictly between 0.5 and 2.
plt.hist(strain_corr_95[strain], bins=bins)
plt.hist(
    strain_corr_95[strain][strain_depth_95[strain].gt(0.5) & strain_depth_95[strain].lt(2)],
    bins=bins,
)
plt.yscale('log')
None

In [None]:
bins = np.linspace(0, 1, num=101)

# Correlation with `strain` for all genes, then only genes whose depth ratio is
# strictly between 0.5 and 2.
plt.hist(strain_corr_90[strain], bins=bins)
plt.hist(
    strain_corr_90[strain][strain_depth_90[strain].gt(0.5) & strain_depth_90[strain].lt(2)],
    bins=bins,
)
plt.yscale('log')
None

In [None]:
# Total depth ratio of the species genes (correlation > 0.95) versus the total
# depth ratio, over the same genes, of the 95%-level genes whose correlation
# with the strain exceeds `q`.
# Fixed: `_strain` is not defined anywhere in the notebook; the selected
# `strain` is used instead.
# NOTE(review): confirm `strain` is the intended strain, and that `q` (computed
# in the strain-vs-species dissimilarity cell from the 75%-level table) is the
# intended cutoff here.
(
    species_depth_ratio[species_corr > 0.95].sum(),
    (strain_depth_95[strain] * (strain_corr_95[strain] > q)).groupby(gene_meta.centroid_75).sum().reindex(idxwhere(species_corr > 0.95)).fillna(0).sum(),
)

In [None]:
# Per species gene (correlation > 0.95): species depth ratio against the summed
# depth ratio of the 95%-level genes whose correlation with the strain exceeds `q`.
# Fixed: `_strain` is not defined anywhere in the notebook; the selected
# `strain` is used instead.
# NOTE(review): confirm `strain` is the intended strain, and that `q` (computed
# in the strain-vs-species dissimilarity cell) is the intended cutoff here.
plt.scatter(
    species_depth_ratio[species_corr > 0.95],
    (strain_depth_95[strain] * (strain_corr_95[strain] > q)).groupby(gene_meta.centroid_75).sum().reindex(idxwhere(species_corr > 0.95)).fillna(0),
    s=5, alpha=0.5,
)
# y = x reference line.
plt.plot([0, 5], [0, 5])

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

# log2 depth ratio for all genes, then only genes whose correlation with
# `strain` exceeds `thresh`.
plt.hist(np.log2(strain_depth_99[strain] + pseudo), bins=bins)
plt.hist(
    np.log2(strain_depth_99.loc[strain_corr_99[strain].gt(thresh), strain] + pseudo),
    bins=bins,
)

plt.yscale('log')
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')

None

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

# log2 depth ratio for all genes, then only genes whose correlation with
# `strain` exceeds `thresh`.
plt.hist(np.log2(strain_depth_95[strain] + pseudo), bins=bins)
plt.hist(
    np.log2(strain_depth_95.loc[strain_corr_95[strain].gt(thresh), strain] + pseudo),
    bins=bins,
)

plt.yscale('log')
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')

None

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

# log2 depth ratio for all genes, then only genes whose correlation with
# `strain` exceeds `thresh`.
plt.hist(np.log2(strain_depth_90[strain] + pseudo), bins=bins)
plt.hist(
    np.log2(strain_depth_90.loc[strain_corr_90[strain].gt(thresh), strain] + pseudo),
    bins=bins,
)

plt.yscale('log')
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')

None

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

sample_list = idxwhere(strain_frac[strain] > 0.95)
gene_list = idxwhere(strain_corr_99[strain] > thresh)
# Depth ratios summed per centroid_75 group: over all genes (d2) and over the
# above-threshold genes only (d3).
d2 = strain_depth_99[strain].groupby(gene_meta.centroid_75).sum()
d3 = strain_depth_99.loc[gene_list, strain].groupby(gene_meta.centroid_75).sum()

plt.hist(np.log2(d2 + pseudo), bins=bins)
plt.hist(np.log2(d3 + pseudo), bins=bins)

# Reference lines at ratios 1, 0.2 and 5.
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')
for ratio_bound in (0.2, 5):
    plt.axvline(np.log2(ratio_bound), lw=0.5, linestyle=':', color='k')

plt.yscale('log')


In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

sample_list = idxwhere(strain_frac[strain] > 0.95)
gene_list = idxwhere(strain_corr_95[strain] > thresh)
# Depth ratios summed per centroid_75 group: over all genes (d2) and over the
# above-threshold genes only (d3).
d2 = strain_depth_95[strain].groupby(gene_meta.centroid_75).sum()
d3 = strain_depth_95.loc[gene_list, strain].groupby(gene_meta.centroid_75).sum()

plt.hist(np.log2(d2 + pseudo), bins=bins)
plt.hist(np.log2(d3 + pseudo), bins=bins)

# Reference lines at ratios 1, 0.2 and 5.
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')
for ratio_bound in (0.2, 5):
    plt.axvline(np.log2(ratio_bound), lw=0.5, linestyle=':', color='k')

plt.yscale('log')


In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

sample_list = idxwhere(strain_frac[strain] > 0.95)
gene_list = idxwhere(strain_corr_90[strain] > thresh)
# Depth ratios summed per centroid_75 group: over all genes (d2) and over the
# above-threshold genes only (d3).
d2 = strain_depth_90[strain].groupby(gene_meta.centroid_75).sum()
d3 = strain_depth_90.loc[gene_list, strain].groupby(gene_meta.centroid_75).sum()

plt.hist(np.log2(d2 + pseudo), bins=bins)
plt.hist(np.log2(d3 + pseudo), bins=bins)

# Reference lines at ratios 1, 0.2 and 5.
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')
for ratio_bound in (0.2, 5):
    plt.axvline(np.log2(ratio_bound), lw=0.5, linestyle=':', color='k')

plt.yscale('log')


In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)
aggregate_at = 'centroid_75'

# 99%-level genes correlated with the strain, and the aggregate groups they belong to.
gene99_list = idxwhere(strain_corr_99[strain] > thresh)
gene_agg_list = list(gene_meta.loc[gene99_list][aggregate_at].unique())

# Per group: depth ratio summed over every member gene (any99) vs. over the
# correlated member genes only (hit99).
d = pd.DataFrame({
    'any99': strain_depth_99[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
    'hit99': strain_depth_99.loc[gene99_list, strain].groupby(gene_meta[aggregate_at]).sum(),
})

plt.hist(np.log2(d['any99']), bins=bins)
plt.hist(np.log2(d['hit99']), bins=bins, alpha=0.7)

None

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)
aggregate_at = 'centroid_75'

# 95%-level genes correlated with the strain, and the aggregate groups they belong to.
gene95_list = idxwhere(strain_corr_95[strain] > thresh)
gene_agg_list = list(gene_meta.loc[gene95_list][aggregate_at].unique())

# Per group: depth ratio summed over every member gene (any95) vs. over the
# correlated member genes only (hit95).
d = pd.DataFrame({
    'any95': strain_depth_95[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
    'hit95': strain_depth_95.loc[gene95_list, strain].groupby(gene_meta[aggregate_at]).sum(),
})

plt.hist(np.log2(d['any95']), bins=bins)
plt.hist(np.log2(d['hit95']), bins=bins, alpha=0.7)

None

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)
aggregate_at = 'centroid_75'

# 90%-level genes correlated with the strain, and the aggregate groups they belong to.
gene90_list = idxwhere(strain_corr_90[strain] > thresh)
gene_agg_list = list(gene_meta.loc[gene90_list][aggregate_at].unique())

# Per group: depth ratio summed over every member gene (any90) vs. over the
# correlated member genes only (hit90).
d = pd.DataFrame({
    'any90': strain_depth_90[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
    'hit90': strain_depth_90.loc[gene90_list, strain].groupby(gene_meta[aggregate_at]).sum(),
})

plt.hist(np.log2(d['any90']), bins=bins)
plt.hist(np.log2(d['hit90']), bins=bins, alpha=0.7)

None

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

gene99_list = idxwhere(strain_corr_99[strain] > thresh)

# One density curve per aggregation level, from the coarsest (75) to the finest (99):
# log2 of the depth ratio summed over the correlated genes of each group.
for aggregate_at in ['centroid_75', 'centroid_80', 'centroid_85', 'centroid_90', 'centroid_95', 'centroid_99']:
    gene_agg_list = list(gene_meta.loc[gene99_list][aggregate_at].unique())

    d = pd.DataFrame({
        'any99': strain_depth_99[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
        'hit99': strain_depth_99.loc[gene99_list, strain].groupby(gene_meta[aggregate_at]).sum(),
    })
    sns.kdeplot(np.log2(d['hit99']), label=aggregate_at)
plt.legend()

In [None]:
thresh, pseudo = 0.9, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

gene95_list = idxwhere(strain_corr_95[strain] > thresh)

# One density curve per aggregation level, from the coarsest (75) to the finest (95):
# log2 of the depth ratio summed over the correlated genes of each group.
for aggregate_at in ['centroid_75', 'centroid_80', 'centroid_85', 'centroid_90', 'centroid_95']:
    gene_agg_list = list(gene_meta.loc[gene95_list][aggregate_at].unique())

    d = pd.DataFrame({
        'any95': strain_depth_95[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
        'hit95': strain_depth_95.loc[gene95_list, strain].groupby(gene_meta[aggregate_at]).sum(),
    })
    sns.kdeplot(np.log2(d['hit95']), label=aggregate_at)
plt.legend()

In [None]:
thresh, pseudo = 0.8, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

gene99_list = idxwhere(strain_corr_99[strain] > thresh)

# One density curve per aggregation level, from the coarsest (75) to the finest (99):
# depth ratio summed over every gene of each group, on a log-scaled axis.
for aggregate_at in ['centroid_75', 'centroid_80', 'centroid_85', 'centroid_90', 'centroid_95', 'centroid_99']:
    gene_agg_list = list(gene_meta.loc[gene99_list][aggregate_at].unique())

    d = pd.DataFrame({
        'any99': strain_depth_99[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
        'hit99': strain_depth_99.loc[gene99_list, strain].groupby(gene_meta[aggregate_at]).sum(),
    })
    sns.kdeplot(d['any99'], log_scale=True, label=aggregate_at)
plt.legend()

In [None]:
thresh, pseudo = 0.8, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

gene95_list = idxwhere(strain_corr_95[strain] > thresh)

# One density curve per aggregation level, from the coarsest (75) to the finest (95):
# log2 of the depth ratio summed over every gene of each group.
for aggregate_at in ['centroid_75', 'centroid_80', 'centroid_85', 'centroid_90', 'centroid_95']:
    gene_agg_list = list(gene_meta.loc[gene95_list][aggregate_at].unique())

    d = pd.DataFrame({
        'any95': strain_depth_95[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
        'hit95': strain_depth_95.loc[gene95_list, strain].groupby(gene_meta[aggregate_at]).sum(),
    })
    sns.kdeplot(np.log2(d['any95']), label=aggregate_at)
plt.legend()

In [None]:
thresh, pseudo = 0.5, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

gene99_list = idxwhere(strain_corr_99[strain] > thresh)

# One density curve per aggregation level, from the coarsest (75) to the finest (99):
# log2 of the depth ratio summed over every gene of each group.
for aggregate_at in ['centroid_75', 'centroid_80', 'centroid_85', 'centroid_90', 'centroid_95', 'centroid_99']:
    gene_agg_list = list(gene_meta.loc[gene99_list][aggregate_at].unique())

    d = pd.DataFrame({
        'any99': strain_depth_99[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
        'hit99': strain_depth_99.loc[gene99_list, strain].groupby(gene_meta[aggregate_at]).sum(),
    })
    sns.kdeplot(np.log2(d['any99']), label=aggregate_at)
plt.legend()

In [None]:
thresh, pseudo = 0.5, 1e-7
bins = np.linspace(np.log2(pseudo), np.log2(1000), num=101)

gene95_list = idxwhere(strain_corr_95[strain] > thresh)

# One density curve per aggregation level, from the coarsest (75) to the finest (95):
# log2 of the depth ratio summed over every gene of each group.
for aggregate_at in ['centroid_75', 'centroid_80', 'centroid_85', 'centroid_90', 'centroid_95']:
    gene_agg_list = list(gene_meta.loc[gene95_list][aggregate_at].unique())

    d = pd.DataFrame({
        'any95': strain_depth_95[strain].groupby(gene_meta[aggregate_at]).sum().loc[gene_agg_list],
        'hit95': strain_depth_95.loc[gene95_list, strain].groupby(gene_meta[aggregate_at]).sum(),
    })
    sns.kdeplot(np.log2(d['any95']), label=aggregate_at)
plt.legend()

In [None]:
# log2(any99) against log2(hit99) for whatever table `d` currently holds.
# NOTE(review): `d` is left over from the previously executed cell. The cell
# directly above builds it with `any95`/`hit95` columns, so the `any99`/`hit99`
# names used here only exist if one of the *99 cells ran last — confirm the
# intended execution order.
plt.scatter('any99', 'hit99', data=np.log2(d), s=1)

In [None]:
# Heatmap contrasting the focal strain with a set of decoy strains: rows are
# samples dominated by either, columns are centroid_75 groups of their
# correlated genes.
# NOTE(review): `decoy_strains` is not defined in any visible cell; it must be
# set to a list of strain labels before this cell can run.
# `axis` is passed by keyword below: pandas >= 2.0 no longer accepts it
# positionally for `DataFrame.any`.
thresh = 0.9

# Samples where the focal strain / any decoy strain exceeds a 0.95 fraction.
sample_list = idxwhere(strain_frac[strain] > 0.95)
decoy_sample_list = idxwhere((strain_frac[decoy_strains] > 0.95).any(axis=1))

# Genes correlated with the focal strain, and genes correlated with any decoy but not with the focal strain.
gene_list = idxwhere(strain_corr_99[strain] > thresh)
decoy_gene_list = list(set(idxwhere((strain_corr_99[decoy_strains] > thresh).any(axis=1))) - set(gene_list))

# The same two gene sets expressed as centroid_75 groups (decoy groups exclude focal groups).
gene75_list = list(gene_meta.loc[gene_list].centroid_75.unique())
decoy_gene75_list = list(set(gene_meta.loc[decoy_gene_list].centroid_75.unique()) - set(gene75_list))

# Colour bars: red marks the focal strain's samples / gene groups, black the decoys'.
row_colors = pd.Series([0]*len(decoy_sample_list) + [1]*len(sample_list), index=decoy_sample_list + sample_list).map({0: 'black', 1: 'red'})
col_colors = pd.Series([0]*len(decoy_gene75_list) + [1]*len(gene75_list), index=decoy_gene75_list + gene75_list).map({0: 'black', 1: 'red'})

# NOTE: This looks at the summed depth for ALL genes in any strain, not just the correlated genes.
d = (
    gene_depth_99
    .sel(sample=sample_list + decoy_sample_list, gene_id=gene_list + decoy_gene_list)
    .groupby(gene_meta.centroid_75.to_xarray().sel(gene_id=(gene_list + decoy_gene_list)))
    .sum()
    .to_series()
    .unstack()
    .T
)
print(d.shape)

sns.clustermap(
    d, metric='cosine',
    norm=mpl.colors.SymLogNorm(linthresh=1e-3),
    row_colors=row_colors,
    col_colors=col_colors,
)

In [None]:
# Display `strain_corr_75` as the cell output for inspection.
strain_corr_75

In [None]:
# Compare the sets of entries passing the same cutoff in `strain_corr_75[strain]`
# and in `species_corr`.
cutoff = 0.95
strain_hit = set(idxwhere(strain_corr_75[strain] > cutoff))
species_hit = set(idxwhere(species_corr > cutoff))

# Sizes of: species hits, strain hits, their union, species-only hits, strain-only hits.
(
    len(species_hit),
    len(strain_hit),
    len(species_hit.union(strain_hit)),
    len(species_hit.difference(strain_hit)),
    len(strain_hit.difference(species_hit)),
)

In [None]:
# Scatter `strain_corr_75[strain]` (x) against `species_corr` (y).
# `align_indexes` comes from lib.pandas_util; presumably it puts both series on a shared
# index so they can be plotted point-for-point -- confirm against the helper.
x, y = align_indexes(strain_corr_75[strain], species_corr)

plt.scatter(x=x, y=y, s=1)

In [None]:
# Histogram of log2 depth for `strain`: all entries, overlaid with those whose
# `strain_corr` value exceeds 0.5.
pseudo = 1e-5
d = strain_depth[strain] + pseudo  # pseudocount keeps log2 finite at zero depth
c = strain_corr[strain]

log2_depth = np.log2(d)
# 100 equal-width bins from the pseudocount floor up to the maximum observed depth.
bins = np.linspace(np.log2(pseudo), np.log2(d.max()), num=101)

plt.hist(log2_depth, bins=bins)
plt.hist(log2_depth[c > 0.5], bins=bins, alpha=0.5)
plt.yscale('log')

In [None]:
# Same histogram as a per-entry depth plot, but with depth first summed within each
# centroid_75 cluster: all entries, then only those with `strain_corr` above `cutoff`.
pseudo = 1e-5
cutoff = 0.5
d = strain_depth[strain] + pseudo
c = strain_corr[strain]

bins = np.linspace(np.log2(pseudo), np.log2(d.max()), num=101)

for depth_subset, hist_kwargs in [(d, {}), (d[c > cutoff], dict(alpha=0.5))]:
    summed_by_cluster = depth_subset.groupby(gene_meta.centroid_75).sum()
    plt.hist(np.log2(summed_by_cluster), bins=bins, **hist_kwargs)

# Vertical reference lines at depth 1 (dashed) and depth 1/5 (dotted).
plt.axvline(np.log2(1), lw=1, linestyle='--', color='k')
plt.axvline(np.log2(1/5), lw=0.5, linestyle=':', color='k')

plt.yscale('log')

In [None]:
# Depth summed per `marker_id`: histogram of all entries vs. entries with `strain_corr`
# above `cutoff`, followed by a scatter of one against the other.
pseudo = 1e-5
cutoff = 0.5
d = strain_depth[strain] + pseudo
c = strain_corr[strain]

# `None` leaves the binning to matplotlib's default; a fixed log2 grid would be
# np.linspace(np.log2(pseudo), np.log2(d.max()), num=101).
bins = None

marker_depth_all = d.groupby(gene_meta.marker_id).sum()
marker_depth_hit = d[(c > cutoff)].groupby(gene_meta.marker_id).sum()

plt.figure()
plt.hist(np.log2(marker_depth_all), bins=bins, alpha=0.5)
plt.hist(np.log2(marker_depth_hit), bins=bins, alpha=0.5)
plt.axvline(0, lw=1, linestyle='--', color='k')  # log2(depth) == 0, i.e. depth 1
plt.yscale('log')


plt.figure()
plt.scatter(x='_all', y='_hit', data=pd.DataFrame({'_all': marker_depth_all, '_hit': marker_depth_hit}))
plt.plot([0, 2], [0, 2])  # y = x reference segment

In [None]:
# Index labels (presumably gene ids -- confirm) whose `species_corr` value exceeds 0.95,
# followed by how many there are (shown as the cell output).
species_gene_list = idxwhere(species_corr > 0.95)
len(species_gene_list)

In [None]:
# Load the two-column, header-less depth-ratio table for this species as a Series.
# NOTE(review): the index is labelled 'sample' here, yet later cells index this Series
# with `species_hit` -- confirm what the first column of the file actually holds.
mean_gene_depth = pd.read_table(
    f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas.species_gene_depth_ratio.tsv',
    names=['sample', 'depth'],
    index_col='sample',
).squeeze()

In [None]:
# KDE of log2 depth at the `species_hit` entries: the `mean_gene_depth` reference (black)
# against the strain depth retained at a series of `strain_corr` cutoffs.
bins = np.linspace(-5, 5, num=101)  # not used by the KDE calls below; kept for kernel state

species_hit = idxwhere(species_corr > 0.95)
sns.kdeplot(np.log2(mean_gene_depth[species_hit] + 1e-5), label='all', color='k', lw=2)

for cutoff in [0.95, 0.9, 0.8, 0.7, 0.5]:
    passing_depth = strain_depth[strain][strain_corr[strain] > cutoff]
    depth_per_cluster = passing_depth.groupby(gene_meta.centroid_75).sum()
    # Entries of `species_hit` with no passing gene get depth 0.
    depth_at_species_hits = depth_per_cluster.reindex(species_hit).fillna(0)
    sns.kdeplot(np.log2(depth_at_species_hits + 1e-5), label=cutoff)
plt.legend()


In [None]:
# Scatter (log2) of `mean_gene_depth` against strain depth at the `species_hit` entries,
# one point cloud per `strain_corr` cutoff.
bins = np.linspace(-5, 5, num=101)  # not used by the plots below; kept for kernel state
pseudo = 1e-1

species_hit = idxwhere(species_corr > 0.95)

# The x values do not depend on the cutoff, so compute them once.
log2_reference_depth = np.log2(mean_gene_depth[species_hit] + pseudo)

for cutoff in [0.95, 0.9, 0.8, 0.7, 0.5]:
    passing_depth = strain_depth[strain][strain_corr[strain] > cutoff]
    depth_at_species_hits = (
        passing_depth
        .groupby(gene_meta.centroid_75)
        .sum()
        .reindex(species_hit)
        .fillna(0)
    )
    plt.scatter(
        log2_reference_depth,
        np.log2(depth_at_species_hits + pseudo),
        label=cutoff,
        s=3,
        alpha=0.5,
    )

# y = x reference line from the pseudocount floor up to depth 10.
diagonal = [np.log2(pseudo), np.log2(1e1)]
plt.plot(diagonal, diagonal)
plt.legend()


In [None]:
# Count the `species_hit` entries that have no centroid_75 cluster among the genes passing
# `cutoff`. `cutoff` is not set in this cell: it is whatever value an earlier cell left behind.
(
    strain_depth[strain][strain_corr[strain] > cutoff]
    .groupby(gene_meta.centroid_75)
    .sum()
    .reindex(species_hit)
    .isna()
    .sum()
)

In [None]:
# Overlaid histograms of per-centroid_75 summed strain depth, from the loosest to the
# strictest `strain_corr` cutoff, printing the total depth retained at each cutoff.
bins = np.linspace(0, 210)  # linear grid with numpy's default number of points

for cutoff in [0.5, 0.7, 0.8, 0.9, 0.95]:
    passing = strain_corr[strain] > cutoff
    d = strain_depth[strain][passing].groupby(gene_meta.centroid_75).sum()
    plt.hist(d, bins=bins)
    print(cutoff, d.sum())

plt.yscale('log')

In [None]:
# Total strain depth over genes with `strain_corr` above 0.7 (summed per centroid_75
# cluster first, then over all clusters).
print(
    strain_depth[strain][strain_corr[strain] > 0.7]
    .groupby(gene_meta.centroid_75)
    .sum()
    .sum()
)

In [None]:
# For each cutoff print: cutoff, total strain depth at the `species_hit` entries, its ratio
# to the total reference depth, and the total reference depth.
# `species_hit` and `pseudo` are taken from earlier cells.
for cutoff in [0.99, 0.95, 0.9, 0.8, 0.7, 0.5, 0]:
    x = mean_gene_depth[species_hit] + pseudo
    passing_depth = strain_depth[strain][strain_corr[strain] > cutoff]
    y = (
        passing_depth
        .groupby(gene_meta.centroid_75)
        .sum()
        .reindex(species_hit)
        .fillna(0)
    )
    x_total = x.sum()
    y_total = y.sum()
    print(cutoff, y_total, y_total / x_total, x_total)

In [None]:
# Histogram (50 bins) of log2 per-centroid_75 strain depth over genes with `strain_corr`
# above 0.7.
plt.hist(
    np.log2(
        strain_depth[strain][strain_corr[strain] > 0.7]
        .groupby(gene_meta.centroid_75)
        .sum()
        .fillna(0)
    ),
    bins=50,
)

In [None]:
# Sum `gene_depth` within centroid_75 clusters, then rename the resulting `centroid_75`
# dimension back to `gene_id`.
gene_depth_75 = (
    gene_depth
    .groupby(gene_meta.loc[gene_depth.gene_id].centroid_75.to_xarray())
    .sum()
    .rename({'centroid_75': 'gene_id'})
)

In [None]:
# Build a 2-D depth table for genes with `strain_corr` above 0.7 in samples where `strain`
# makes up more than 95%, then sum rows within centroid_75 clusters.
strain_gene_ids = idxwhere(strain_corr[strain] > 0.7)
strain_dominated_samples = idxwhere(strain_frac[strain] > 0.95)

depth_table = (
    gene_depth
    .sel(gene_id=strain_gene_ids, sample=strain_dominated_samples)
    .to_series()
    .unstack()
)
d = depth_table.groupby(gene_meta.centroid_75).sum()
# Optional normalisation, currently disabled:
# d = d / species_gene_depth.loc[d.columns]

In [None]:
# Clustered heatmap of `d` using cosine distance; colours follow a power-law (gamma = 1/7)
# scale limited to the range 0.1-100.
sns.clustermap(
    d,
    norm=mpl.colors.PowerNorm(1/7, vmin=0.1, vmax=100),
    metric='cosine',
)

In [None]:
# Histogram (51 bins) of the log2 of each row's mean in `d`.
plt.hist(np.log2(d.mean(axis=1)), bins=51)

In [None]:
# Clustered heatmap of log2-transformed `d`, using cosine distance.
sns.clustermap(
    data=np.log2(d),
    metric='cosine',
)