In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Every data path used below (ref/..., meta/..., data/...) is relative, so switch to the
# project root named by the PROJECT_ROOT environment variable (KeyError if it is unset).
_os.chdir(_os.environ['PROJECT_ROOT'])

# Display the fully resolved working directory as this cell's output.
_os.path.realpath(_os.curdir)

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import lib.thisproject.data
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable
import lib.thisproject.data

### Set Style

In [None]:
# Global plot styling: seaborn "talk" context and a low DPI so inline figures stay compact.
sns.set_context(context='talk')
plt.rcParams.update({'figure.dpi': 50})

In [None]:
# Species taxonomy table loaded via the project helper; later cells group by its
# p__/c__/taxonomy_string columns.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    "ref/gtpro/species_taxonomy_ext.tsv"
)
species_taxonomy

In [None]:
# Per-genome metadata from ref/midasdb_uhgg/genomes.tsv, indexed by genome ID.
# The first len("GUT_GENOME") characters of every ID are replaced by "UHGG"
# (assumes every ID carries the "GUT_GENOME" prefix; this is not checked).
genomes_meta = (
    pd.read_csv('ref/midasdb_uhgg/genomes.tsv', sep='\t', index_col='genome')
    .rename_axis(index='genome_id')
    .rename(index=lambda genome: 'UHGG' + genome[len("GUT_GENOME"):])
)
genomes_meta

In [None]:
# Reference genome table with IDs normalised like genomes_meta ("GUT_GENOME..." -> "UHGG..."),
# then left-joined with genomes_meta on genome ID. `species` is read as str so it matches
# the string species IDs used in the rest of the notebook.
reference_meta = (
    pd.read_csv(
        'ref/uhgg_genomes_all_4644.tsv',
        sep='\t',
        dtype={'species': str},
        index_col='Genome',
    )
    .rename_axis(index='genome_id')
    .rename(index=lambda genome: 'UHGG' + genome[len("GUT_GENOME"):])
    .join(genomes_meta)
)
reference_meta

In [None]:
# - How many samples
# Analysis group; samples whose IDs start with "xjin_" are excluded below.
group = 'xjin_hmp2'

sample_list = (
    pd.read_table('meta/mgen_group.tsv')
    .loc[lambda df: df.mgen_group.eq(group) & ~df.mgen_id.str.startswith('xjin_'), 'mgen_id']
    .tolist()
)
assert len(sample_list) == len(set(sample_list))  # sample IDs must be unique
len(sample_list)

In [None]:
# - How many species analyzed
# Species listed for this group in meta/species_group.tsv.
species_list1 = (
    pd.read_table('meta/species_group.tsv')
    .loc[lambda df: df['species_group_id'] == group, 'species_id']
    .tolist()
)
assert len(species_list1) == len(set(species_list1))  # species IDs must be unique
len(species_list1)

In [None]:
# - How many species found in xjin_hmp2?
# TODO: Remind myself of what my species filters were.
# IDs are cast to str so they can be interpolated into paths and matched to other string IDs.
# `lib.pandas_util` is reachable through `lib` because the import cell loaded that submodule.
species_list2 = list(map(str, lib.pandas_util.read_list(
    'data/group/xjin_hmp2/r.proc.pangenomes/pangenomes.species'
)))
len(species_list2)

In [None]:
# - For each species:
# Single species used to walk through the filtering steps interactively.
species_id = str(100022)

In [None]:
#   - How many species-x-samples pairs at sufficient depth (1x)
# Per-sample species depth restricted to `sample_list`; samples absent from the file get 0.
species_depth = (
    pd.read_table(
        f'data/group/xjin_hmp2/species/sp-{species_id}/r.proc.gene99-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv',
        names=['sample', 'depth'],
        index_col=['sample'],
    )['depth']
    .reindex(sample_list, fill_value=0)
)
sample_list2 = idxwhere(species_depth.gt(1))
len(sample_list2)

In [None]:
#   - How many strains were these collapsed into (with at least 1x depth)
# Strain relative abundances (samples x strains). reindex(fill_value=0) is used instead of .loc
# so that samples without SFacts output count as zero abundance rather than raising KeyError
# (the same treatment as in the per-species loop).
strain_frac = (
    pd.read_table(
        f'data/group/xjin_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv',
        index_col=['sample', 'strain'],
    )
    .community.unstack('strain')
    .reindex(sample_list2, fill_value=0)
)
# Per-sample strain depth: relative abundance scaled row-wise by that sample's species depth.
strain_depth = strain_frac.mul(species_depth.loc[sample_list2], axis=0)
strain_list = idxwhere((strain_depth > 1).any())
len(strain_list)

In [None]:
#   - How many of these had at least one "pure" sample
# "Pure" here means the strain's relative abundance in a sample exceeds 95%.
strain_list2 = idxwhere(strain_frac[strain_list].gt(0.95).any(axis=0))
len(strain_list2)

In [None]:
#   - How many passed "species gene frac" threshold?
strain_meta = pd.read_table(
    f'data/group/xjin_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99-v22-agg75.spgc.strain_meta.tsv',
    index_col='strain',
)
# Strains absent from strain_meta get species_gene_frac = 0 and therefore fail the >0.9 cut.
strain_list3 = idxwhere(
    strain_meta['species_gene_frac'].reindex(strain_list2, fill_value=0).gt(0.9)
)
len(strain_list3)

In [None]:
#   - How many passed gene count filtering? (These are our final numbers)
# Gene-count outlier filter: fit a Student's t with df fixed at 2 (presumably for a location/scale
# estimate that is robust to outliers), then keep strains strictly inside the 0.1%-99.9%
# quantiles of a normal distribution with that fitted location and scale.
x = strain_meta.loc[strain_list3].num_genes
if len(x) < 1:
    # t.fit cannot be run on empty data; mirrors the guard in the per-species loop.
    strain_list4 = []
else:
    _df, _loc, _scale = sp.stats.t.fit(x.values, fix_df=2)
    _dist1 = sp.stats.norm(_loc, _scale)
    thresh_max_num_uhgg_genes = _dist1.ppf(0.999)
    thresh_min_num_uhgg_genes = _dist1.ppf(0.001)
    strain_list4 = idxwhere((x > thresh_min_num_uhgg_genes) & (x < thresh_max_num_uhgg_genes))
len(strain_list4)

In [None]:
def _gene_count_inliers(num_genes):
    """Return the labels of `num_genes` that are not gene-count outliers.

    Location and scale are estimated by fitting a Student's t distribution with df fixed
    at 2 (presumably for robustness to outliers). Labels strictly inside the 0.1%-99.9%
    quantiles of a normal distribution with that location/scale are kept.
    Empty input returns [] (t.fit cannot run on empty data).
    """
    if len(num_genes) < 1:
        return []
    _, loc, scale = sp.stats.t.fit(num_genes.values, fix_df=2)
    band = sp.stats.norm(loc, scale)
    lower, upper = band.ppf(0.001), band.ppf(0.999)
    return idxwhere((num_genes > lower) & (num_genes < upper))


def _summarize_species_strains(species_id, samples):
    """Run the strain-filtering cascade for one species.

    Parameters
    ----------
    species_id : str
        Species ID, interpolated into the per-species data paths.
    samples : list
        Sample IDs to consider; samples missing from the depth table get depth 0.

    Returns
    -------
    tuple of (pd.Series, pd.DataFrame), or None
        Per-stage counts, and a per-strain table of boolean filter flags (one row per
        inferred strain). Returns None, after printing the error, if the SFacts or
        SPGC output file is missing.
    """
    base = f'data/group/xjin_hmp2/species/sp-{species_id}'

    # Species-x-sample pairs with non-trivial species depth.
    # NOTE(review): this threshold is 0.05x, while older comments describe the step as ">1x";
    # confirm which is intended.
    species_depth = (
        pd.read_table(
            f'{base}/r.proc.gene99-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv',
            names=['sample', 'depth'],
            index_col=['sample'],
        )
        .depth.reindex(samples, fill_value=0)
    )
    species_samples = idxwhere(species_depth > 0.05)

    try:
        strain_frac = (
            pd.read_table(f'{base}/r.proc.gtpro.sfacts-fit.comm.tsv', index_col=['sample', 'strain'])
            .community.unstack('strain')
            .reindex(species_samples, fill_value=0)
        )
    except FileNotFoundError as err:
        print(f"SFacts output missing for {species_id}.")
        print(err)
        return None

    # "Inferred" strains reach >50% relative abundance in at least one sample.
    inferred = idxwhere((strain_frac > 0.5).any())
    # ...of which those with at least one "pure" (>95%) sample.
    with_pure_sample = idxwhere((strain_frac[inferred] > 0.95).any())

    try:
        # reindex (not .loc): strains absent from the SPGC metadata get NaN rows and fail the
        # filters below instead of raising an uncaught KeyError.
        strain_meta = pd.read_table(
            f'{base}/r.proc.gtpro.sfacts-fit.gene99-v22-agg75.spgc.strain_meta.tsv',
            index_col='strain',
        ).reindex(with_pure_sample)
    except FileNotFoundError as err:
        print(f"SPGC output missing for {species_id}.")
        print(err)
        return None

    # Total strain depth >1x (NOTE: sum_depth includes xjin samples).
    sufficient_depth = idxwhere(strain_meta.sum_depth > 1)
    # Species-gene fraction ("completeness") >90%.
    complete = idxwhere(strain_meta.loc[sufficient_depth].species_gene_frac > 0.9)
    # Not a gene-count outlier. These are the final numbers.
    passing = _gene_count_inliers(strain_meta.loc[complete].num_genes)

    counts = pd.Series(dict(
        num_species_samples=len(species_samples),  # Species depth >0.05x (see NOTE above).
        num_inferred_strains=len(inferred),  # >50% in at least one sample.
        num_strains_with_pure_sample=len(with_pure_sample),  # At least one >95% sample.
        num_strains_with_sufficient_depth=len(sufficient_depth),  # >1x summed depth.
        num_complete_spgc=len(complete),  # Species gene frac >90%.
        num_passing_spgc=len(passing),  # Not a gene count outlier.
    ))
    details = pd.DataFrame(index=inferred).assign(
        species=species_id,
        strain=lambda df: df.index,
        has_inference=True,
        has_pure_sample=lambda df: df.index.isin(with_pure_sample),
        has_sufficient_depth=lambda df: df.index.isin(sufficient_depth),
        has_species_genes=lambda df: df.index.isin(complete),
        has_reasonable_gene_count=lambda df: df.index.isin(passing),
    )
    return counts, details


# Apply the cascade to every species. The work lives in functions so that it does not
# overwrite the single-species exploration variables (strain_frac, strain_meta, x, ...).
species_strain_counts = {}
strain_details = []
for species_id in tqdm(species_list2):
    summary = _summarize_species_strains(species_id, sample_list)
    if summary is None:
        continue
    species_strain_counts[species_id], species_details = summary
    strain_details.append(species_details)
species_strain_counts = pd.DataFrame(species_strain_counts).T
strain_details = pd.concat(strain_details).set_index(['species', 'strain'])

In [None]:
# Total count at each filtering stage, summed over all species.
species_strain_counts.sum(axis='index')

In [None]:
# Number of strains flagged True at each stage (boolean columns sum to counts).
strain_details.sum(axis='index')

In [None]:
# Stage counts aggregated by "phylum;class", ordered by strains passing all filters.
(
    species_strain_counts
    .groupby(species_taxonomy.apply(lambda row: row.p__ + ';' + row.c__, axis=1))
    .sum()
    .sort_values(by='num_passing_spgc', ascending=False)
)

In [None]:
# Species belonging to class Coriobacteriia.
species_taxonomy.loc[species_taxonomy['c__'] == 'c__Coriobacteriia']

In [None]:
# Stage counts aggregated by "phylum;class", in the groupby's default order.
(
    species_strain_counts
    .groupby(species_taxonomy.apply(lambda row: row.p__ + ';' + row.c__, axis=1))
    .sum()
)

In [None]:
# One line per phylum tracing its count through the filtering stages (x-axis), symlog y-axis.
d = (
    species_strain_counts
    .groupby(species_taxonomy.apply(lambda row: row.p__, axis=1))
    .sum()
    .sort_values('num_passing_spgc', ascending=False)
    .T
)
_palette = lib.plot.construct_ordered_palette(d.columns, cm='rainbow')

fig, ax = plt.subplots(figsize=(3, 6))
for taxon, stage_counts in d.items():
    ax.plot(stage_counts, c=_palette[taxon], label=taxon, lw=3, alpha=0.8)
ax.set_yscale('symlog', linthresh=1)
ax.set_ylim(1)
lib.plot.rotate_xticklabels()
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Per-phylum bars for each filtering stage, overlaid in column order (later stages drawn on
# top), on a log y-axis; phylum names have their "p__" prefix removed.
d = (
    species_strain_counts
    .groupby(species_taxonomy.apply(lambda row: row.p__, axis=1))
    .sum()
    .sort_values('num_passing_spgc', ascending=False)
    .rename(index=lambda taxon: taxon[len('p__'):])
    .rename_axis('phylum')
)
_palette = lib.plot.construct_ordered_palette(d.columns, cm='Spectral')

fig, ax = plt.subplots(figsize=(9, 5))
for level, heights in d.items():
    ax.bar(d.index, heights, color=_palette[level])
ax.set_yscale('log')
ax.set_ylabel('Count')
ax.set_ylim(0.1)
ax.set_yticks(np.logspace(0, 5, num=6), minor=False)
ax.set_yticks([], minor=True)
lib.plot.rotate_xticklabels(rotation=25)

# Separate legend-only figure: zero-height bars carry the colour/label pairs and every axis
# element is hidden.
_rename_levels = dict(
    num_species_samples='Species-Sample Pairs',  # Species depth threshold as applied in the loop.
    num_inferred_strains='Strains Inferred',  # >50% in at least one sample.
    num_strains_with_pure_sample='Strains w/ Pure Samples (>95%)',  # At least one "pure" sample.
    num_strains_with_sufficient_depth='+ Sufficient Depth (>1x)',  # Summed depth; includes xjin samples.
    num_complete_spgc='+ High "Completeness" (>90%)',  # Species gene frac >90%.
    num_passing_spgc='+ Appropriate Gene Count',  # Not a gene count outlier.
)
fig, ax = plt.subplots()
for level in d.columns:
    ax.bar(d.index, 0, color=_palette[level], label=_rename_levels[level])
ax.legend()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)

In [None]:
# Per-"phylum;class" bars for each filtering stage, overlaid in column order, log y-axis,
# with a legend keyed by stage column name.
d = (
    species_strain_counts
    .groupby(species_taxonomy.apply(lambda x: x.p__ + ';' + x.c__, axis=1))
    .sum()
    .sort_values('num_passing_spgc', ascending=False)
)
_palette = lib.plot.construct_ordered_palette(d.columns, cm='Spectral')

# Create the figure explicitly. Bare plt.bar would draw onto whatever figure is still open
# (e.g. the legend-only figure under a non-inline backend).
fig, ax = plt.subplots()
for level in d.columns:
    ax.bar(d.index, d[level], color=_palette[level], label=level)

ax.set_yscale('log')
ax.set_ylabel('Count')
ax.legend(bbox_to_anchor=(1, 1))
ax.set_ylim(0.1)
ax.set_yticks(np.logspace(0, 5, num=6), minor=False)
ax.set_yticks([], minor=True)
lib.plot.rotate_xticklabels()

In [None]:
# The 20 "phylum;class" groups with the most strains passing all filters.
d.sort_values(by='num_passing_spgc', ascending=False).iloc[:20].index

In [None]:
# Stage counts per full taxonomy string, shown for the 40 taxa with the most species-sample
# pairs; bars for each stage are overlaid on a symlog y-axis.
d = (
    species_strain_counts
    .groupby(species_taxonomy.taxonomy_string)
    .sum()
    .sort_values('num_passing_spgc', ascending=False)
)
taxon_list = d.sort_values('num_species_samples', ascending=False).iloc[:40].index
_palette = lib.plot.construct_ordered_palette(d.columns, cm='rainbow')

fig, ax = plt.subplots(figsize=(15, 5))
for level in d.columns:
    ax.bar(taxon_list, d.loc[taxon_list, level], color=_palette[level], label=level)
ax.legend(bbox_to_anchor=(1, 1))

ax.set_yscale('symlog', linthresh=1, linscale=0.1)
ax.set_ylim(0.1)
ax.set_yticks(np.logspace(0, 5, num=6), minor=False)
ax.set_yticks([], minor=True)
lib.plot.rotate_xticklabels()

In [None]:
# Species whose species-level name contains "fragilis".
species_taxonomy.loc[species_taxonomy['s__'].str.contains('fragilis')]

In [None]:
# Per-species stage counts, joined with the number of reference genomes of each Genome_type
# and with the taxonomy; show the 40 species with the most species-sample pairs.
d = species_strain_counts.join(
    reference_meta
    .groupby('species')['Genome_type']
    .value_counts()
    .unstack(fill_value=0)
    .rename(index=str)
).join(species_taxonomy)

d.sort_values(by='num_species_samples', ascending=False).head(40)

In [None]:
# Strains passing all filters next to the number of Isolate and MAG reference genomes,
# for the 40 species with the most passing strains (log y-axis).
d = (
    species_strain_counts
    .join(
        reference_meta
        .groupby('species')['Genome_type']
        .value_counts()
        .unstack(fill_value=0)
        .rename(index=str)
    )
    .join(species_taxonomy)
    .sort_values(by='num_passing_spgc')
    .loc[:, ['num_passing_spgc', 'Isolate', 'MAG']]
)

taxon_list = d.sort_values(by='num_passing_spgc', ascending=False).head(40).index

fig, ax = plt.subplots(figsize=(15, 5))
d.loc[taxon_list].plot.bar(ax=ax)
ax.set_yscale('log')