In [None]:
%load_ext autoreload

In [None]:
# Work from the project root so the relative paths used below resolve.
import os as _os

_project_root = _os.environ['PROJECT_ROOT']  # raises KeyError when the variable is unset
_os.chdir(_project_root)
# Echo the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os

In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

In [None]:
# Global plotting defaults: seaborn's 'talk' context and a 50-dpi figure canvas.
sns.set_context('talk')
plt.rcParams.update({'figure.dpi': 50})

In [None]:
# Identifiers that are interpolated into the file paths below.
group_subset = 'xjin'
group = 'xjin_hmp2'
stemA = 'r.proc'

# Registry of input files keyed by a short name; later cells add more entries.
path = {
    'species_taxonomy': "ref/gtpro/species_taxonomy_ext.tsv",
    'all_species_depth': f"data/group/{group_subset}/{stemA}.gtpro.species_depth.tsv",
    'midasdb_genomes': "ref/uhgg_genomes_all_4644.tsv",
    'strain_genomes': "meta/genome.tsv",
}

# Fail fast, listing every registered file that is missing on disk.
path_exists = {filename: os.path.exists(filename) for filename in path.values()}
missing = [filename for filename, found in path_exists.items() if not found]
assert all(path_exists.values()), '\n'.join(["Missing files:"] + missing)

In [None]:
# Ridge plot of relative-abundance distributions for the most prevalent species.
# NOTE(review): assumes `load_species_depth` returns a pandas DataFrame with one
# row per sample and one column per species -- confirm against the loader.
species_depth = lib.thisproject.data.load_species_depth(path['all_species_depth'])
# Normalize every row by its own total (vectorized; same result as a row-wise apply).
rabund = species_depth.div(species_depth.sum(axis=1), axis=0)

n_species = 90
min_rabund = 1e-5  # relative abundance above which a species counts as present
top_species = (rabund > min_rabund).sum().sort_values(ascending=False).head(n_species).index

# Size the grid by the number of species actually selected, so no empty panels
# are drawn when fewer than `n_species` columns exist. `squeeze=False` keeps
# `axs` an array even for a single panel.
n_panels = len(top_species)
fig, axs = plt.subplots(
    n_panels,
    figsize=(10, 0.5 * n_panels),
    sharex=True,
    sharey=True,
    squeeze=False,
)
axs = axs[:, 0]

# Log-spaced bins from 1e-7 to 1.
bins = np.logspace(-7, 0, num=51)

for species_id, ax in zip(top_species, axs):
    ax.hist(rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    # Fraction of rows in which this species exceeds the presence threshold.
    prevalence = (rabund[species_id] > min_rabund).mean()
    ax.set_title("")
    # Strip all decoration so the overlapping panels read as a single ridge plot.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-7)
    ax.set_ylim(top=300)

# Restore the shared x-axis on the bottom panel only, addressed explicitly
# rather than through the loop variable left over from the last iteration.
bottom_ax = axs[-1]
bottom_ax.xaxis.set_visible(True)
bottom_ax.spines['bottom'].set_visible(True)

# Negative spacing makes the panels overlap vertically.
fig.subplots_adjust(hspace=-0.75)

In [None]:
# Focal species id; the cells below use it to filter genomes and build file paths.
species = '102492'

# Load the species taxonomy table and display the entry for the focal species.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(path["species_taxonomy"])
species_taxonomy.loc[species]

In [None]:
# Read the genome metadata table with every column kept as a string, then
# display the rows that belong to the focal species.
strain_genome = pd.read_csv(path["strain_genomes"], sep='\t', dtype=str)
strain_genome.loc[strain_genome['species_id'] == species]

In [None]:
# Genome ids in the metadata table that belong to the focal species.
strain_genome_ids = strain_genome[strain_genome.species_id == species].genome_id
print(strain_genome_ids)
# Validate BEFORE indexing: with zero matches the original order raised a bare
# IndexError from `[0]` and never reached the assertion.
assert strain_genome_ids.shape[0] == 1, (
    f"Expected exactly one genome for species {species}, "
    f"found {strain_genome_ids.shape[0]}"
)
strain_genome_id = strain_genome_ids.tolist()[0]

In [None]:
# Register the per-genome '-blastp.tsv' tables for the focal strain genome.
path.update({
    # uhgg_x_strain=f'data/species/sp-{species}/genome/midas_uhgg_pangenome.{strain_genome_id}-blastp.tsv',
    'strain_x_uhgg': f'data/species/sp-{species}/genome/{strain_genome_id}.midas_uhgg_pangenome-blastp.tsv',
    'strain_x_strain': f'data/species/sp-{species}/genome/{strain_genome_id}.{strain_genome_id}-blastp.tsv',
})

# Re-validate the whole registry, listing every file that is missing on disk.
path_exists = {filename: os.path.exists(filename) for filename in path.values()}
missing = [filename for filename, found in path_exists.items() if not found]
assert all(path_exists.values()), '\n'.join(["Missing files:"] + missing)

In [None]:
# Default file paths for interactive use, built from the parameters below.

stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts43-s85-seed0'
stemC = 'sfacts42-seed0'
spgc_params = 'e100'
centroid = 75

path.update({
    'flag': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.gene{centroid}.spgc-{spgc_params}.strain_files.flag",
    'fit': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.world.nc",
    'refit': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.world.nc",
    'strain_correlation': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{centroid}.spgc-{spgc_params}.strain_correlation.tsv",
    'strain_depth_ratio': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{centroid}.spgc-{spgc_params}.strain_depth_ratio.tsv",
    'strain_fraction': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.comm.tsv",
    # species_gene=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.gene{centroid}.spgc-{spgc_params}.species_gene.list",
    'species_gene_mean_depth': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.gene{centroid}.spgc-{spgc_params}.species_depth.tsv",
    'species_gtpro_depth': f"data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    'species_correlation': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.gene{centroid}.spgc.species_correlation.tsv",
    'species_gene': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.gene{centroid}.spgc.species_gene.list",
    'strain_thresholds': f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{centroid}.spgc-{spgc_params}.strain_gene_threshold.tsv",
    'gene_annotations': f"ref/midasdb_uhgg_gene_annotations/sp-{species}.gene{centroid}_annotations.tsv",
    'raw_gene_depth': f"data/group/{group}/species/sp-{species}/{stemA}.gene{centroid}.depth.nc",
    'reference_copy_number': f"ref/midasdb_uhgg_pangenomes/{species}/gene{centroid}.reference_copy_number.nc",
    'cluster_info': f"ref/midasdb_uhgg/pangenomes/{species}/cluster_info.txt",
    'gtpro_reference_genotype': f"data/species/sp-{species}/gtpro_ref.mgtp.nc",
})

# Re-validate the whole registry, listing every file that is missing on disk.
path_exists = {filename: os.path.exists(filename) for filename in path.values()}
missing = [filename for filename, found in path_exists.items() if not found]
assert all(path_exists.values()), '\n'.join(["Missing files:"] + missing)

In [None]:
# Display the registered 'flag' path for the current parameter choices.
path['flag']

In [None]:
# Load the reference copy-number DataArray from its registered path and display it.
# The cells below index it by the `gene_id` and `genome_id` dimensions.
reference_copy_number = xr.load_dataarray(path['reference_copy_number'])
reference_copy_number

In [None]:
# Total copy number summed over all genes, i.e. collapsed along `gene_id`.
# NOTE(review): presumably one value per genome -- confirm the array has no
# dimensions besides `gene_id` and `genome_id`.
gene_count = reference_copy_number.sum("gene_id")

In [None]:
# Quantile cut-off applied to each tail of the gene-count distribution; only
# genomes strictly between the two cut-offs are used below.
trim_quantile = 0.25
# A gene is called single-copy when the fraction of retained genomes carrying
# exactly one copy exceeds this value.
threshold = 0.95

In [None]:
# Lower and upper gene-count cut-offs at the symmetric trimming quantiles.
quantile_levels = [trim_quantile, 1 - trim_quantile]
low_q, high_q = gene_count.quantile(quantile_levels)

In [None]:
# Histogram of gene counts with the two trimming cut-offs as dashed lines.
cutoff_style = dict(lw=1, linestyle='--', color='k')
plt.hist(gene_count, bins=50)
plt.axvline(low_q, **cutoff_style)
plt.axvline(high_q, **cutoff_style)

In [None]:
# Boolean mask of genomes whose gene count lies strictly between the cut-offs.
above_low_cutoff = gene_count > low_q
below_high_cutoff = gene_count < high_q
inter_quantile_range_genomes = above_low_cutoff & below_high_cutoff

In [None]:
# Restrict to the retained genomes, then take the fraction of them in which
# each entry has copy number exactly 1 (mean over `genome_id`).
retained_copy_number = reference_copy_number.sel(genome_id=inter_quantile_range_genomes)
frac_single_copy = (retained_copy_number == 1).mean("genome_id")

In [None]:
# Display the fraction of retained genomes with copy number exactly 1.
frac_single_copy

In [None]:
# Distribution of single-copy fractions between 0.8 and 1 (log-scaled counts),
# with the calling threshold drawn as a dashed line.
frac_bins = np.linspace(0.8, 1, num=101)
plt.hist(frac_single_copy, bins=frac_bins)
plt.axvline(threshold, lw=1, linestyle='--', color='k')
plt.yscale('log')

In [None]:
# Index labels of the genes whose single-copy fraction exceeds the threshold.
passes_threshold = (frac_single_copy > threshold).to_series()
single_copy_genes = idxwhere(passes_threshold)

In [None]:
# Number of genes that passed the single-copy threshold.
len(single_copy_genes)