In [None]:
%load_ext autoreload

In [None]:
import os as _os
# Run from the project root so the relative data/ and ref/ paths below resolve.
# A missing PROJECT_ROOT raises KeyError immediately rather than failing later.
_project_root = _os.environ['PROJECT_ROOT']
_os.chdir(_project_root)
# Display the resolved working directory as this cell's output.
_os.path.realpath(_os.curdir)

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import lib.thisproject.data
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable
import sfacts as sf
import lib.thisproject.data

In [None]:
species_depth = lib.thisproject.data.load_species_depth("data/group/xjin_hmp2/r.proc.gtpro.species_depth.tsv")
# Each row of species_depth divided by its own sum (rows are samples, columns species).
rabund = species_depth.div(species_depth.sum(axis=1), axis=0)

# Ridge-style stack of per-species depth histograms for the most prevalent species.
DETECTION_THRESH = 1e-3  # depth above which a species counts as present in a sample
n_species = 40
top_species = (species_depth > DETECTION_THRESH).sum().sort_values(ascending=False).head(n_species).index

# Size the grid to the species actually selected (can be fewer than n_species), and
# use squeeze=False so `axs` is always an array, even for a single panel.
fig, axs = plt.subplots(
    len(top_species), 1,
    figsize=(5, 0.25 * len(top_species)),
    sharex=True, sharey=True, squeeze=False,
)
axs = axs[:, 0]

bins = np.logspace(-3, 4, num=51)

for species_id, ax in zip(top_species, axs):
    ax.hist(species_depth[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    prevalence = (species_depth[species_id] > DETECTION_THRESH).mean()
    # Strip per-panel decoration so the stacked panels read as one ridge plot.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    # NOTE(review): xlim starts far below the first bin edge (1e-3); presumably to
    # leave room for the label on the left — confirm.
    ax.set_xlim(left=1e-7)
    ax.set_ylim(top=300)

# Show the shared x-axis on the bottom panel only.
axs[-1].xaxis.set_visible(True)
axs[-1].spines['bottom'].set_visible(True)

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Across species: how many samples have depth > 0.5. The extra 0 edge gives species
# that never exceed the cutoff their own [0, 1) bin, kept visible by the symlog axis.
bins = [0, *np.logspace(0, 4)]
_samples_above_half = (species_depth > 0.5).sum()
plt.hist(_samples_above_half, bins=bins)
plt.xscale('symlog', linthresh=1.0, linscale=0.1)

In [None]:
species_taxonomy = lib.thisproject.data.load_species_taxonomy("ref/gtpro/species_taxonomy_ext.tsv")
# Top 10 species by number of samples with depth > 0.5. Only the new count column is
# zero-filled (species missing from species_depth); taxonomy columns keep their NaNs
# instead of being overwritten with 0.
(
    species_taxonomy
    .assign(num_samples_with_depth_gt_half=(species_depth > 0.5).sum())
    .fillna({'num_samples_with_depth_gt_half': 0})
    .astype({'num_samples_with_depth_gt_half': int})
    .sort_values('num_samples_with_depth_gt_half', ascending=False)
    .head(10)
    .drop(columns=['taxonomy_string'])
)

In [None]:
# UHGG genome metadata, with index IDs shortened from "GUT_GENOME..." to "UHGG...".
reference_meta = (
    pd.read_table('ref/uhgg_genomes_all_4644.tsv', index_col='Genome')
    .rename_axis(index='genome_id')
    .rename(index=lambda genome: 'UHGG' + genome[len("GUT_GENOME"):])
)
reference_meta

In [None]:
# Number of genomes (non-null MGnify_accession) per species representative, ascending.
(
    reference_meta
    .groupby('Species_rep')['MGnify_accession']
    .count()
    .rename_axis(index='genome_id')
    .rename(index=lambda rep: 'UHGG' + rep[len("GUT_GENOME"):])
    .sort_values()
)

In [None]:
def _to_uhgg_genome_id(gut_genome_id):
    """Shorten a 'GUT_GENOME...' identifier to its 'UHGG...' form."""
    return 'UHGG' + gut_genome_id[len("GUT_GENOME"):]


# UHGG metadata and the MIDAS DB genome table, both indexed by shortened genome ID.
ref_table0 = (
    pd.read_table('ref/uhgg_genomes_all_4644.tsv', index_col='Genome')
    .rename_axis(index='genome_id')
    .rename(index=_to_uhgg_genome_id)
)
ref_table1 = (
    pd.read_table('ref/midasdb_uhgg/genomes.tsv')
    .set_index('genome')
    .rename(index=_to_uhgg_genome_id)
)
# Both tables must describe exactly the same genomes before joining.
assert set(ref_table0.index) == set(ref_table1.index)

midasdb_genome_data = ref_table0.join(ref_table1)
midasdb_genome_data

In [None]:
# Per-species genome counts from the MIDAS DB, indexed by species ID as str.
midasdb_genome_tally_total = (
    pd.read_table('ref/midasdb_uhgg/genomes.tsv')
    .groupby('species').genome.count()
    .rename(str)
)
# Previously this repeated the total tally, so "isolate" counts included every genome.
# Count only genomes whose UHGG Genome_type is 'Isolate'. Species with no isolate
# genomes get an explicit 0.
# NOTE(review): assumes Genome_type takes UHGG's 'Isolate'/'MAG' values — confirm.
assert 'Genome_type' in midasdb_genome_data.columns, "expected UHGG 'Genome_type' column"
midasdb_genome_tally_isolate = (
    midasdb_genome_data[midasdb_genome_data.Genome_type == 'Isolate']
    .groupby('species').size()
    .rename(str)
    .reindex(midasdb_genome_tally_total.index, fill_value=0)
)

In [None]:
# Raw SPGC strain count per species. Species without a strain_meta file have the
# FileNotFoundError printed and are omitted from the resulting Series.
num_strains_raw = {}
for species_id in species_depth.columns:
    _strain_meta_path = f"data/group/xjin_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99-v22-agg75.spgc.strain_meta.tsv"
    try:
        _spgc_meta = pd.read_table(_strain_meta_path, index_col='strain').rename(str)
    except FileNotFoundError as err:
        print(err)
    else:
        num_strains_raw[species_id] = len(_spgc_meta)

num_strains_raw = pd.Series(num_strains_raw)

In [None]:
# Per species: number of samples with depth > 0.1 vs number of raw SPGC strains.
# Only the two count columns are zero-filled, so taxonomy columns keep their NaNs.
d = (
    species_taxonomy
    .assign(
        num_samples_with_depth_gt_tenth=(species_depth > 0.1).sum(),
        num_strains_raw=num_strains_raw,
    )
    .fillna({'num_samples_with_depth_gt_tenth': 0, 'num_strains_raw': 0})
)
plt.scatter('num_samples_with_depth_gt_tenth', 'num_strains_raw', data=d)
# symlog instead of log: a log axis silently drops the zero-filled species.
plt.yscale('symlog')
plt.xscale('symlog')
plt.xlabel('num_samples_with_depth_gt_tenth')
plt.ylabel('num_strains_raw')

In [None]:
# Gather SPGC strain metadata across species and flag strains that pass QC.
# A strain passes when all of the following hold:
#   - sum_depth > 1
#   - species_gene_frac > 0.9
#   - num_genes lies between the 0.1% and 99.9% quantiles of a normal distribution.
#     Its loc/scale come from a Student-t fit (df fixed at 2) to the num_genes of
#     the species' strains with species_gene_frac > 0.9.
all_strain_meta = []

for species_id in tqdm(species_depth.columns):
    try:
        _spgc_meta = pd.read_table(f"data/group/xjin_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99-v22-agg75.spgc.strain_meta.tsv", index_col='strain').rename(str)
    except FileNotFoundError:
        continue  # No SPGC output for this species; skip it.

    _complete_num_genes = _spgc_meta.loc[_spgc_meta.species_gene_frac > 0.9, 'num_genes']
    if _complete_num_genes.empty:
        # No strain has >90% of species genes, so none can pass. Use NaN thresholds
        # instead of reusing the previous species' values; the old branch also wrote
        # to an undefined `num_strains_filt` and raised NameError.
        thresh_max_num_uhgg_genes = np.nan
        thresh_min_num_uhgg_genes = np.nan
    else:
        _df, _loc, _scale = sp.stats.t.fit(_complete_num_genes.values, fix_df=2)
        _dist1 = sp.stats.norm(_loc, _scale)
        thresh_max_num_uhgg_genes = _dist1.ppf(0.999)
        thresh_min_num_uhgg_genes = _dist1.ppf(0.001)

    all_strain_meta.append(_spgc_meta.assign(
        species_id=species_id,
        thresh_max_num_uhgg_genes=thresh_max_num_uhgg_genes,
        thresh_min_num_uhgg_genes=thresh_min_num_uhgg_genes,
        # NaN thresholds make both bound comparisons False, so such strains fail.
        passes_filt=lambda m: (
            (m.sum_depth > 1)
            & (m.species_gene_frac > 0.9)
            & (m.num_genes <= m.thresh_max_num_uhgg_genes)
            & (m.num_genes >= m.thresh_min_num_uhgg_genes)
        ),
    ))

assert all_strain_meta, "No SPGC strain_meta files were found for any species."
all_strain_meta = pd.concat(all_strain_meta).assign(
    num_sample=lambda x: x.num_sample.astype(int),
    num_genes=lambda x: x.num_genes.astype(int),
)

In [None]:
# Distribution of per-strain total depth. The dashed line marks the sum_depth > 1 QC cutoff.
_sum_depth_bins = np.logspace(-3, 5)
plt.hist(all_strain_meta.sum_depth, bins=_sum_depth_bins)
plt.xscale('log')
plt.axvline(1.0, color='k', linestyle='--', lw=1);

In [None]:
# Species with the narrowest accepted num_genes window (max minus min threshold).
_thresh_cols = ['thresh_max_num_uhgg_genes', 'thresh_min_num_uhgg_genes']
x = all_strain_meta.groupby('species_id')[_thresh_cols].first()
x = x.assign(width=x.thresh_max_num_uhgg_genes - x.thresh_min_num_uhgg_genes)
x.width.sort_values().head(10)

In [None]:
# Display the per-strain sample counts as integers.
all_strain_meta['num_sample'].astype(int)

In [None]:
# Total number of strains passing QC across all species.
all_strain_meta['passes_filt'].sum()

In [None]:
# Pairwise view of the strain QC metrics on log scales, colored by QC outcome.
# Small offsets keep zeros finite under log10. species_gene_frac_n1ml is
# -log10(1 - species_gene_frac), which spreads out values close to 1.
_pairplot_vars = [
    'num_sample',
    'max_depth',
    'sum_depth',
    'species_gene_frac_n1ml',
    'num_genes',
    'strain_metagenotype_entropy',
]
_strain_meta_log = all_strain_meta.assign(
    num_sample=np.log10(all_strain_meta.num_sample + 1e-1),
    max_depth=np.log10(all_strain_meta.max_depth + 1e-3),
    sum_depth=np.log10(all_strain_meta.sum_depth + 1e-3),
    species_gene_frac_n1ml=-np.log10(1 - all_strain_meta.species_gene_frac + 1e-3),
    num_genes=np.log10(all_strain_meta.num_genes + 1),
    strain_metagenotype_entropy=np.log10(all_strain_meta.strain_metagenotype_entropy + 1e-4),
)
pg = sns.pairplot(
    _strain_meta_log,
    vars=_pairplot_vars,
    hue='passes_filt',
    kind='scatter',
    plot_kws={'s': 4, 'lw': 0},
)

In [None]:
# One row per species: SPGC strain counts (raw and QC-passing), sample coverage,
# MIDAS DB genome tallies and taxonomy, sorted by QC-passing strain count.
all_species_meta = (
    all_strain_meta
    .groupby('species_id')
    # Named aggregation gives the same two int columns as the former
    # groupby.apply(lambda x: pd.Series(...)). It avoids a Python call per group
    # and apply's deprecated handling of grouping columns.
    .agg(
        num_strains_filt=('passes_filt', 'sum'),
        num_strains_raw=('passes_filt', 'size'),
    )
    .assign(
        num_samples_with_depth_gt_half=(species_depth > 0.5).sum(),
        num_filtered_out=lambda x: x.num_strains_raw - x.num_strains_filt,
        num_strains_midasdb=midasdb_genome_tally_total,
        num_isolates_midasdb=midasdb_genome_tally_isolate,
    )
    .join(species_taxonomy)
    .sort_values('num_strains_filt', ascending=False)
)
all_species_meta

In [None]:
# Raw vs QC-passing strain counts per species, colored by the number of samples with
# depth > 0.5. The order-palette setup that used to sit here was never used in this
# plot and has been removed.
plt.scatter(
    'num_strains_raw',
    'num_strains_filt',
    data=all_species_meta,
    c='num_samples_with_depth_gt_half',
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
)
plt.colorbar(label='num_samples_with_depth_gt_half')
plt.xlabel('num_strains_raw')
plt.ylabel('num_strains_filt')
plt.xscale('symlog')
plt.yscale('symlog')

In [None]:
# Raw vs QC-passing strain counts per species, one color per taxonomic order (o__),
# with orders taken from most to least frequent.
o__order = all_species_meta['o__'].value_counts().index.values
top_o = o__order[:10]
# NOTE(review): the palette is built from the top 10 orders only, but it is indexed
# with every order below; presumably construct_ordered_palette provides a fallback
# color for other keys — confirm.
o__palette = lib.plot.construct_ordered_palette(top_o, cm='tab20')

for o__ in o__order:
    plt.scatter(
        'num_strains_raw',
        'num_strains_filt',
        data=all_species_meta[all_species_meta.o__ == o__],
        # color= rather than c=: a single RGB(A) tuple passed as c= is value-mapped
        # through the colormap when an order has exactly 3 or 4 species.
        color=o__palette[o__],
        label=o__,
    )
plt.xlabel('num_strains_raw')
plt.ylabel('num_strains_filt')
plt.xscale('symlog')
plt.yscale('symlog')
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# MIDAS DB genome count vs QC-passing SPGC strain count per species, colored by
# taxonomic order (o__), with a y = x reference line.
o__order = all_species_meta['o__'].value_counts().index.values
top_o = o__order[:10]
# NOTE(review): the palette is built from the top 10 orders only, but it is indexed
# with every order below; presumably construct_ordered_palette provides a fallback
# color for other keys — confirm.
o__palette = lib.plot.construct_ordered_palette(top_o, cm='tab20')

for o__ in o__order:
    plt.scatter(
        'num_strains_midasdb',
        'num_strains_filt',
        data=all_species_meta[all_species_meta.o__ == o__],
        # color= rather than c=: a single RGB(A) tuple passed as c= is value-mapped
        # through the colormap when an order has exactly 3 or 4 species.
        color=o__palette[o__],
        label=o__,
    )
plt.ylabel('num_strains_filt')
plt.xlabel('num_strains_midasdb')
plt.xscale('symlog')
plt.yscale('symlog')
plt.plot([1, 100], [1, 100])  # y = x reference
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Samples with depth > 0.5 vs MIDAS DB genome count per species, colored by
# taxonomic order (o__).
o__order = all_species_meta['o__'].value_counts().index.values
top_o = o__order[:10]
# NOTE(review): the palette is built from the top 10 orders only, but it is indexed
# with every order below; presumably construct_ordered_palette provides a fallback
# color for other keys — confirm.
o__palette = lib.plot.construct_ordered_palette(top_o, cm='tab20')

for o__ in o__order:
    plt.scatter(
        'num_samples_with_depth_gt_half',
        'num_strains_midasdb',
        data=all_species_meta[all_species_meta.o__ == o__],
        # color= rather than c=: a single RGB(A) tuple passed as c= is value-mapped
        # through the colormap when an order has exactly 3 or 4 species.
        color=o__palette[o__],
        label=o__,
    )
plt.ylabel('num_strains_midasdb')
plt.xlabel('num_samples_with_depth_gt_half')
plt.xscale('symlog')
plt.yscale('symlog')
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Species ordered by how many raw strains the QC filter removed, most first.
all_species_meta.sort_values(
    by='num_filtered_out',
    ascending=False,
)

In [None]:
# The 30 species with the most QC-passing strains.
_n_top_species = 30
all_species_meta.sort_values(by='num_strains_filt', ascending=False).iloc[:_n_top_species]

In [None]:
# QC-passing SPGC strains per MIDAS DB isolate genome, top 50 species by that ratio.
_strain_isolate_ratio = all_species_meta.num_strains_filt / all_species_meta.num_isolates_midasdb
(
    all_species_meta
    .assign(spgc_strain_to_midas_isolate_ratio=_strain_isolate_ratio)
    .sort_values('spgc_strain_to_midas_isolate_ratio', ascending=False)
    .head(50)
)