### Preamble

#### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so every relative data/ref/meta path below resolves.
# A KeyError here means the PROJECT_ROOT environment variable is not set.
_project_root = _os.environ['PROJECT_ROOT']
_os.chdir(_project_root)
_os.path.realpath(_os.curdir)  # show the resolved working directory

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable


In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

#### Set Style

In [None]:
# Large "talk" context for legibility, with a low figure DPI so the many stacked panels stay compact.
sns.set_context('talk')
plt.rcParams.update({'figure.dpi': 50})

## Set Parameters / Load Initial Data

In [None]:
# Registries of path templates and of the parameters used to fill them; extended cell by cell below.
path_patterns, path_params = {}, {}

In [None]:
# Add new parameters.
path_params.update(
    group_subset='ucfmt',
    group='ucfmt',
    stemA='r.proc',
)

# Add new patterns.
path_patterns.update(
    species_taxonomy="ref/gtpro/species_taxonomy_ext.tsv",
    all_species_depth_subset="data/group/{group_subset}/{stemA}.gtpro.species_depth.tsv",
    all_species_depth="data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    midasdb_genomes="ref/uhgg_genomes_all_4644.tsv",
    strain_genomes="meta/genome.tsv",
)

# Generic step: re-run after every new batch of path_patterns / path_params is added.
# Render every template with the current parameters, then fail loudly listing any missing files.
path = {key: template.format(**path_params) for key, template in path_patterns.items()}
_path_exists = {p: os.path.exists(p) for p in path.values()}
assert all(_path_exists.values()), '\n'.join(["Missing files:"] + [p for p, ok in _path_exists.items() if not ok])

### Species Abundance

In [None]:
species_depth = lib.thisproject.data.load_species_depth(path['all_species_depth'])
species_depth_subset = lib.thisproject.data.load_species_depth(path['all_species_depth_subset'])
# Relative abundance: divide each row (presumably a sample) by its total depth across species columns.
rabund = species_depth.div(species_depth.sum(axis=1), axis=0)
rabund_subset = species_depth_subset.div(species_depth_subset.sum(axis=1), axis=0)

n_species = 40
# Relative-abundance cutoff for "detected"; used both to rank species and for the annotated prevalence.
detection_thresh = 1e-5
top_species = (rabund_subset > detection_thresh).sum().sort_values(ascending=False).head(n_species).index

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-7, 0, num=51)

for species_id, ax in zip(top_species, axs):
    # Subset table first, full-group table second, overlaid on the same log-spaced bins.
    ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    prevalence = (rabund_subset[species_id] > detection_thresh).mean()
    ax.set_title("")
    # Strip axes, spines and background so the overlapping panels read as a ridgeline plot.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-7)
    ax.set_ylim(top=300)

# Restore the x-axis on the last panel drawn.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the panels overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
n_species = 40
# Depth cutoff for "detected"; used both to rank species and for the annotated prevalence,
# so the percentages agree with the ranking criterion.
depth_detection_thresh = 1e-3
top_species = (species_depth_subset > depth_detection_thresh).sum().sort_values(ascending=False).head(n_species).index

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-3, 4, num=51)

for species_id, ax in zip(top_species, axs):
    # Subset table first, full-group table second, overlaid on the same log-spaced bins.
    ax.hist(species_depth_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(species_depth[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    prevalence = (species_depth_subset[species_id] > depth_detection_thresh).mean()
    ax.set_title("")
    # Strip axes, spines and background so the overlapping panels read as a ridgeline plot.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-4)
    ax.set_ylim(top=300)

# Restore the x-axis on the last panel drawn.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the panels overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
n_species = 40
detection_thresh = 1e-5  # relative-abundance cutoff for "detected"
# Species ranked n_species+1 .. 2*n_species by prevalence. Positional slicing keeps this list
# disjoint from the top-n_species panel even when fewer than 2*n_species species exist
# (head(2n).tail(n) would then overlap with it).
second_species = (
    (rabund_subset > detection_thresh).sum()
    .sort_values(ascending=False)
    .iloc[n_species:2 * n_species]
    .index
)

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-7, 0, num=51)

for species_id, ax in zip(second_species, axs):
    # Subset table first, full-group table second, overlaid on the same log-spaced bins.
    ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    prevalence = (rabund_subset[species_id] > detection_thresh).mean()
    ax.set_title("")
    # Strip axes, spines and background so the overlapping panels read as a ridgeline plot.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-7)
    ax.set_ylim(top=300)

# Restore the x-axis on the last panel drawn.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the panels overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
n_species = 40
# Depth cutoff for "detected"; used both to rank species and for the annotated prevalence.
depth_detection_thresh = 1e-3
# Species ranked n_species+1 .. 2*n_species; positional slicing keeps this disjoint from the
# top-n_species panel even when fewer than 2*n_species species exist.
second_species = (
    (species_depth_subset > depth_detection_thresh).sum()
    .sort_values(ascending=False)
    .iloc[n_species:2 * n_species]
    .index
)

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-3, 4, num=51)

for species_id, ax in zip(second_species, axs):
    # Subset table first, full-group table second, overlaid on the same log-spaced bins.
    ax.hist(species_depth_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(species_depth[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    prevalence = (species_depth_subset[species_id] > depth_detection_thresh).mean()
    ax.set_title("")
    # Strip axes, spines and background so the overlapping panels read as a ridgeline plot.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-4)
    ax.set_ylim(top=300)

# Restore the x-axis on the last panel drawn.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the panels overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
# Hierarchically cluster rows and columns of the (pseudocount-shifted) depth table by cosine
# distance; a gamma=1/5 power norm compresses the wide dynamic range of depths for display.
sns.clustermap(
    species_depth_subset.add(1e-5),
    metric='cosine',
    norm=mpl.colors.PowerNorm(gamma=1 / 5),
)

In [None]:
# Pairwise cosine *distance* (despite the name, not a correlation) between the columns
# (species) of the depth table; 1 - species_depth_corr is the cosine similarity.
_cosine_dist = sp.spatial.distance.pdist(species_depth_subset.T, metric='cosine')
species_depth_corr = pd.DataFrame(
    sp.spatial.distance.squareform(_cosine_dist),
    index=species_depth_subset.columns,
    columns=species_depth_subset.columns,
)

In [None]:
# Cosine similarity between species (1 - cosine distance), hierarchically clustered.
sns.clustermap(species_depth_corr.rsub(1), figsize=(20, 20))

#### !!!! Set Focal Species

In [None]:
# Focal species for everything below (alternative used previously: '100236').
species = '100022'

species_taxonomy = (
    lib.thisproject.data
    .load_species_taxonomy(path["species_taxonomy"])
)
species_taxonomy.loc[species]  # display the focal species' taxonomy

##### ^^^^^^

### Set Ground-truth Reference Strain

In [None]:
# Table of reference genomes (all columns read as strings); display those for the focal species.
strain_genome = pd.read_table(path["strain_genomes"], dtype='str')
strain_genome.query('species_id == @species')

In [None]:
# Exactly one reference genome is expected for the focal species. Check this before
# indexing, so a missing or ambiguous match fails with a clear message instead of an IndexError.
strain_genome_ids = strain_genome.loc[strain_genome.species_id == species, 'genome_id']
print(strain_genome_ids)
assert strain_genome_ids.shape[0] == 1, (
    f"Expected exactly 1 reference genome for species {species}, found {strain_genome_ids.shape[0]}."
)
strain_genome_id = strain_genome_ids.iloc[0]

### Set Gene Family Clustering-level

In [None]:
# Gene-family clustering level: selects the `centroid_{centroid}` column of the pangenome
# cluster_info table (presumably a percent-identity level, cf. centroid_99).
centroid = 75

In [None]:
# Add new parameters.
path_params.update(
    centroid=centroid,
    species=species,
    strain_genome_id=strain_genome_id,
)

# Add new patterns.
path_patterns.update(
    strain_cds_length='data/species/sp-{species}/genome/{strain_genome_id}.prodigal-single.cds.nlength.tsv',
    strain_x_uhgg_bitscore_ratio='data/species/sp-{species}/genome/{strain_genome_id}.midas_uhgg_pangenome_new-blastn.bitscore_ratio-c{centroid}.tsv',
)

# Generic step: re-run after every new batch of path_patterns / path_params is added.
# Render every template with the current parameters, then fail loudly listing any missing files.
path = {key: template.format(**path_params) for key, template in path_patterns.items()}
_path_exists = {p: os.path.exists(p) for p in path.values()}
assert all(_path_exists.values()), '\n'.join(["Missing files:"] + [p for p, ok in _path_exists.items() if not ok])

### !!!! Set other SPGD parameters and check paths

In [None]:
# Add new parameters.
path_params.update(dict(
    stemB='sfacts-fit',
    # stemC='sfacts42-seed0',
    gene_params=f"99_new-v22-agg{centroid}",
    # thresh_params='thresh-corrq10-depth300',
    corr_thresh="100",
    depth_thresh="250",
    specgene_params='specgene-ref-t25-p95',
    ss_params="all",  # alternative: "xjin-deepest-n10"
    trnsfm=30,
))

# Add new patterns.
path_patterns.update(dict(
    flag="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_t-{trnsfm}_thresh-corr{corr_thresh}-depth{depth_thresh}.strain_files.flag",
    fit="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.world.nc",
    # refit="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.world.nc",
    strain_correlation="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_t-{trnsfm}.strain_correlation.tsv",
    strain_depth_ratio="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_t-{trnsfm}.strain_depth_ratio.tsv",
    strain_fraction="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.comm.tsv",
    species_gene_mean_depth="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_{specgene_params}.species_depth.tsv",
    species_gtpro_depth="data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    # NOTE: path['species_correlation'] is read further down in this notebook; it is disabled here.
    # species_correlation="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-denovo2-t30-n500.species_correlation.tsv",
    species_gene="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_{specgene_params}.species_gene.list",
    # species_gene_denovo="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-denovo-n500.species_gene.list",
    # species_gene_denovo2="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-denovo2-t30-n500.species_gene.list",
    species_gene_reference="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-ref-t25-p95.species_gene.list",
    species_free_samples="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-ref-t25-p95.species_free_samples.list",
    # Uses {stemA} like every other group-level pattern (was hardcoded to 'r.proc').
    strain_samples="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.spgc_ss-{ss_params}.strain_samples.tsv",
    strain_thresholds="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_t-{trnsfm}_thresh-corr{corr_thresh}-depth{depth_thresh}.strain_gene_threshold.tsv",
    gene_annotations="data/species/sp-{species}/midasdb_uhgg_new.gene_annotations.tsv",
    # raw_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.pangenome95.gene{centroid}_depth.nc",
    # norm_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.gene99-mapq0-agg{centroid}.normed_depth2.nc",
    raw_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.depth2.nc",
    # raw_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.gene{centroid}.normed_depth2.nc",
    reference_copy_number="data/species/sp-{species}/gene{centroid}_new.reference_copy_number.nc",
    cluster_info="ref/midasdb_uhgg_new/pangenomes/{species}/cluster_info.txt",
    gtpro_reference_genotype="data/species/sp-{species}/gtpro_ref.mgtp.nc",
    reference_strain_accuracy="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_t-{trnsfm}_thresh-corr{corr_thresh}-depth{depth_thresh}.{strain_genome_id}.uhgg-reconstruction_accuracy.tsv",
    # reference_strain_accuracy_depth_only="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_t-{trnsfm}_thresh-corr0-depth{depth_thresh}.{strain_genome_id}.gene_content_reconstruction_accuracy.tsv",
    # reference_strain_mapping_q0="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq0-agg{centroid}.depth2.nc",
    reference_strain_mapping_q0="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene{gene_params}.depth2.nc",
    # reference_strain_mapping_q1="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq1-agg{centroid}.depth2.nc",
    # reference_strain_mapping_q2="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq2-agg{centroid}.depth2.nc",
    # reference_strain_mapping_q4="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq4-agg{centroid}.depth2.nc",
    # xjin_benchmarking="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-xjin-{ss_params}_t-{trnsfm}_thresh-corr{corr_thresh}-depth{depth_thresh}.xjin_strain_summary.tsv",
))

# Generic step: re-run after every new batch of path_patterns / path_params is added.
# Render every template with the current parameters, then fail loudly listing any missing files.
path = {k: path_patterns[k].format(**path_params) for k in path_patterns}
_path_exists = {}
for p in path:
    _path_exists[path[p]] = os.path.exists(path[p])
assert all(_path_exists.values()), '\n'.join(["Missing files:"] + [p for p in _path_exists if not _path_exists[p]])

In [None]:
# Display the resolved `.strain_files.flag` target path for the current parameter set.
path["flag"]

#### ^^^^^^

### SFacts results

In [None]:
fit = sf.World.load(path['fit'])
print(fit.sizes)
# Subsample up to 1000 positions for plotting. The seed makes this reproducible, assuming
# random_sample draws from NumPy's global RNG (NOTE(review): confirm in sfacts).
np.random.seed(0)
position_ss = fit.random_sample(position=min(fit.sizes['position'], 1000)).position

# Restrict to samples that are also in the subset depth table, keeping the fit's own sample
# order. A set intersection would give an order that varies between interpreter runs
# because string hashing is randomized.
_subset_samples = set(species_depth_subset.index)
fit_subset = fit.sel(sample=[s for s in fit.sample.values if s in _subset_samples])

# fuzzy_geno = sf.Genotype.load(path['fit'])  # FIXME: refit
# fuzzy_geno = sf.World.from_combined(fuzzy_geno, fit.metagenotype, fit.community)

sf.evaluation.metagenotype_error2(fit)[0]

#### Plotting

In [None]:
# Drop strains using sfacts' drop_low_abundance_strains(0.05), then keep only the subsampled
# positions (position_ss) so the heatmaps stay legible.
w0 = fit.drop_low_abundance_strains(0.05).sel(position=position_ss)

# Metagenotype heatmap: rows are positions, columns are samples; both orders come from the metagenotype.
sf.plot.plot_metagenotype(
    w0,
    col_linkage_func=lambda w: w.metagenotype.linkage("sample"),
    row_linkage_func=lambda w: w.metagenotype.linkage("position"),
)
# Disabled alternatives: sf.plot.plot_depth / sf.plot.plot_dominance on fit.sel(position=position_ss),
# with rows linked by w.metagenotype.linkage("position") and columns by w.community.linkage().

# Community heatmap: strains ordered by genotype similarity, samples by metagenotype similarity.
sf.plot.plot_community(
    w0,
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
    col_linkage_func=lambda w: w.metagenotype.linkage("sample"),
    scaley=0.3,
)

# Genotype heatmap: strains x positions, positions ordered by the same metagenotype linkage as above.
sf.plot.plot_genotype(
    w0,
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
    col_linkage_func=lambda w: w.metagenotype.linkage("position"),
    scaley=0.2,
)

#### Identify focal sfacts strain

In [None]:
# Mean community fraction of each inferred strain across the subset samples.
_mean_strain_frac = fit_subset.community.mean("sample")
_mean_strain_frac_series = _mean_strain_frac.to_series()
print(_mean_strain_frac_series.sort_values(ascending=False).head(5))
top_inferred_strain = _mean_strain_frac_series.idxmax()

# Warn when even the most abundant strain does not average more than 90% of the community.
if not _mean_strain_frac.sel(strain=top_inferred_strain) > 0.9:
    print("WARNING: No dominant strain found in fit subset.")

In [None]:
# Display the path of the reference-strain reconstruction-accuracy table.
path['reference_strain_accuracy']

In [None]:
# Reconstruction accuracy table (rows indexed by the first column, looked up by inferred strain
# below), sorted best F1 first. The file is read once and reused for both outputs.
_strain_accuracy = (
    pd.read_table(path['reference_strain_accuracy'], index_col=0)
    .sort_values('f1', ascending=False)
)
print(_strain_accuracy.loc[top_inferred_strain, ['precision', 'recall', 'f1']])
_strain_accuracy.head(5)

### Gene Annotations

In [None]:
# Pangenome cluster membership, indexed by the centroid_99 column (kept as a column too) as 'gene_id'.
gene_cluster = (
    pd.read_table(path["cluster_info"])
    .set_index('centroid_99', drop=False)
    .rename_axis(index='gene_id')
)

# Headerless annotation table; column names are supplied here and then lower-cased.
_annotation_columns = ['locus_tag', 'ftype', 'length_bp', 'gene', 'EC_number', 'COG', 'product']
gene_annotation = (
    pd.read_table(path["gene_annotations"], names=_annotation_columns, index_col='locus_tag')
    .rename(columns=str.lower)
)

# One row per gene family at the chosen clustering level (its centroid gene), with annotations joined on.
_family_centroids = gene_cluster[f'centroid_{centroid}'].unique()
gene_meta = gene_cluster.loc[_family_centroids].join(gene_annotation)

### Reference Strain Gene Matching

In [None]:
# Column names, in order, of BLAST tabular output (the 12 default fields of -outfmt 6).
blastp_header_names = (
    'qseqid sseqid pident length mismatch gapopen '
    'qstart qend sstart send evalue bitscore'
).split()

In [None]:
# Length of each predicted CDS (ORF) in the reference genome, indexed by ORF id
# (presumably in nucleotides, per the `.nlength` suffix).
# squeeze('columns') always returns a Series; a bare squeeze() would collapse a one-row table to a scalar.
orf_length = pd.read_table(path['strain_cds_length'], names=['orf', 'length'], index_col=['orf']).squeeze('columns')

# Bitscore ratio of each (reference ORF, UHGG gene family) pair, as a Series with an (orf, gene) MultiIndex.
orf_x_midas = pd.read_table(path['strain_x_uhgg_bitscore_ratio'], index_col=['orf', 'gene']).squeeze('columns')
# orf_x_midas = pd.read_table('data/species/sp-101380/genome/Ruminococcus-gnavus-ATCC-29149_MinIONHybrid.midas_uhgg_pangenome-blastp.bitscore_ratio-c75.tsv', index_col=['orf', 'gene']).squeeze()


# _strain_x_strain = (
#     pd.read_table(
#         path['strain_x_strain'],
#         names=blastp_header_names
#     )
# )

# _max_bitscore = _strain_x_strain.groupby(['qseqid']).bitscore.max()

# strain_x_uhgg = (
#     pd.read_table(
#         path['strain_x_uhgg'],
#         names=blastp_header_names
#     )
#     .assign(bitscore_ratio=lambda x: x.bitscore / x.qseqid.map(_max_bitscore))
#     .assign(sseq_centroid=lambda x: x.sseqid.map(gene_cluster[f'centroid_{centroid}']))
# )

# best_uhgg_hit = strain_x_uhgg.groupby('qseqid').apply(lambda d: d.sort_values('bitscore').iloc[-1]).groupby('sseq_centroid').bitscore_ratio.max()

In [None]:
# Best bitscore ratio per gene family (max over reference ORFs), overlaid with the best ratio
# per reference ORF (max over gene families). Missing pairs count as 0.
_orf_by_gene = orf_x_midas.unstack(fill_value=0)  # rows: ORFs, columns: gene families
bins = np.linspace(0, 1)
plt.hist(_orf_by_gene.max(axis=0), bins=bins, density=True)
plt.hist(_orf_by_gene.max(axis=1), bins=bins, density=True, alpha=0.5)
plt.yscale('log');

In [None]:
# For each reference ORF, count gene families with bitscore ratio > 0.95; then tally those counts.
orf_x_midas.unstack().astype(float).gt(0.95).sum(axis=1).value_counts().sort_index()

### Strain-specific correlations/depth

In [None]:
# Per (gene_id, strain) scores pivoted to genes x strains. squeeze('columns') guarantees a
# Series before unstacking (a bare squeeze() would return a scalar for a one-row table).
strain_corr = (
    pd.read_table(path["strain_correlation"], index_col=['gene_id', 'strain'])
    .squeeze('columns')
    .unstack('strain', fill_value=0)
)
strain_depth = (
    pd.read_table(path["strain_depth_ratio"], index_col=['gene_id', 'strain'])
    .squeeze('columns')
    .unstack()  # last level ('strain') to columns; missing pairs become NaN
)
# strain_corr, strain_depth = align_indexes(*align_indexes(strain_corr, strain_depth), axis="columns")

### Strain Metadata

In [None]:
# Per-strain gene-selection thresholds, renamed to the column names used in this notebook.
strain_thresholds = pd.read_table(path["strain_thresholds"], index_col='strain').rename(
    columns={
        'correlation': 'corr_threshold',
        'depth_high': 'depth_thresh_high',
        'depth_low': 'depth_thresh_low',
    }
)

# Strain metadata: thresholds joined with the entropy of each inferred genotype.
_strain_meta = strain_thresholds.join(
    fit.genotype.entropy().to_series().rename('genotype_entropy')
)
strain_meta = _strain_meta

strain_meta.loc[[top_inferred_strain]]

In [None]:
# Species-correlation scores, if configured. The `species_correlation` path pattern is currently
# commented out in the path setup, so indexing `path` directly raised KeyError; fall back to an
# empty Series so later cells that reference species_corr still run.
if 'species_correlation' in path:
    species_corr = pd.read_table(
        path["species_correlation"], names=['sample', 'correlation'], index_col='sample'
    ).squeeze('columns')
else:
    print("WARNING: no 'species_correlation' path configured; using an empty species_corr.")
    species_corr = pd.Series(dtype=float, name='correlation')

### SPGC Species Genes

In [None]:
# Gene IDs, one per line, selected as species genes by the configured SPGC method and by
# the reference-based method. The de novo lists (species_gene_denovo / _denovo2) are not loaded;
# their path patterns are disabled.
with open(path["species_gene"]) as f:
    species_gene_hit = list(map(str.strip, f))

with open(path["species_gene_reference"]) as f:
    species_gene_reference_hit = list(map(str.strip, f))

### Strain gene definition

In [None]:
# Reference-genome tile depths under two gene-clustering versions (presumably old 'gene99-v22'
# vs new 'gene99_new-v22'). Paths follow the focal `species` and `centroid` instead of
# hardcoding sp-100022 / agg75, which broke the sample selection whenever `species` changed.
# NOTE(review): these files come from the xjin_hmp2 group, not path_params['group']; confirm intended.
_reference_sample = f'sp-{species}/genome/{strain_genome_id}'
_tiles_stem = f'data/group/xjin_hmp2/species/sp-{species}/ALL_STRAINS.tiles-l100-o99'
td0 = xr.load_dataarray(f'{_tiles_stem}.gene99-v22-agg{centroid}.depth2.nc').sel(sample=_reference_sample).to_series()
td1 = xr.load_dataarray(f'{_tiles_stem}.gene99_new-v22-agg{centroid}.depth2.nc').sel(sample=_reference_sample).to_series()

In [None]:
# Tile depth of the reference genome across gene families (the 'q0' mapping), for all families
# and for the species-gene subset (missing families filled with 0).
reference_strain_tile_depth_q0 = (
    xr.load_dataarray(path['reference_strain_mapping_q0'])
    .sel(sample=f'sp-{species}/genome/{strain_genome_id}')
)

# A 0.1 pseudocount puts zero depths into the first log-spaced bin.
_pseudocount = 1e-1
bins = np.logspace(-1, 3, num=100)
for _depths in (
    reference_strain_tile_depth_q0,
    reference_strain_tile_depth_q0.reindex(gene_id=species_gene_hit, fill_value=0),
):
    plt.hist(_depths + _pseudocount, bins=bins, alpha=0.5)
plt.xscale('log')
plt.yscale('log')
plt.axvline(200, lw=1, linestyle='-', color='k')
plt.axvline(50, lw=1, linestyle='--', color='k')

### !!!! Per-gene Strain Corr/Depth

In [None]:
_strain = top_inferred_strain

# Thresholds for the focal strain as chosen by the pipeline (strain_meta), plus manual cutoffs
# for the reference-genome evidence.
depth_threshold = strain_meta.depth_thresh_low.loc[_strain]
corr_threshold = strain_meta.corr_threshold.loc[_strain]
# corr_threshold = 0.95  # Set manually, but this could/should be the automatically selected threshold.
# depth_threshold = 0.2  # Set manually, but this could/should be the automatically selected threshold.
bitscore_threshold = 0.95  # cutoff on the best reference-ORF bitscore ratio (bitscore_hit)
tile_depth_threshold = 30  # cutoff on reference tile depth (tile_depth_hit)

# One row per gene family: strain/species scores plus reference-genome evidence, with indicator
# columns used to colour the diagnostic scatter plots.
strain_scores = (
    pd.DataFrame(dict(
        bitscore_ratio=orf_x_midas.unstack(fill_value=0).max(),
        strain_corr=strain_corr[_strain],
        strain_depth=strain_depth[_strain],
        species_corr=species_corr,
        tile_depth_q0=reference_strain_tile_depth_q0.to_series().reindex(strain_corr.index, fill_value=0),
    ))
    .fillna(0)
    .assign(
        bitscore_hit=lambda x: x.bitscore_ratio >= bitscore_threshold,
        not_bitscore_hit=lambda x: x.bitscore_ratio < bitscore_threshold,
        tile_depth_hit=lambda x: x.tile_depth_q0 >= tile_depth_threshold,
        not_tile_depth_hit=lambda x: x.tile_depth_q0 < tile_depth_threshold,
        depth_hit=lambda x: (x.strain_depth > depth_threshold),
        corr_and_depth_hit=lambda x: (x.strain_corr > corr_threshold) & (x.strain_depth > depth_threshold),
        species_gene=lambda x: x.index.to_series().isin(species_gene_hit),
        # species_gene_denovo_hit / species_gene_denovo_hit2 are never defined (their loading code
        # and path patterns are commented out), so referencing them would raise NameError. These
        # columns are kept as all-False placeholders so plots that list them still run; restore the
        # `x.index.to_series().isin(...)` lookups if those lists are re-enabled.
        species_gene_denovo=False,
        species_gene_denovo2=False,
        species_gene_reference=lambda x: x.index.to_series().isin(species_gene_reference_hit),
        corr_complement=lambda x: 1 - x.strain_corr,
        dummy=False,
        gene_length=gene_cluster.groupby(f'centroid_{centroid}').centroid_99_length.mean(),
    )
    .sort_values('bitscore_ratio')
)

In [None]:
# One panel per indicator: gene families plotted as (1 - strain correlation) vs strain depth
# ratio and coloured by the indicator (sorted so high values are drawn last, on top).
indicator_list = [
        'bitscore_hit',
        'not_bitscore_hit',
        'tile_depth_hit',
        'not_tile_depth_hit',
        'species_gene',
        'species_gene_denovo',
        'species_gene_denovo2',
        'species_gene_reference',
        'gene_length',
        'tile_depth_q0',
        # 'tile_depth_q2',
    ]

fig, axs = lib.plot.subplots_grid(ncols=2, naxes=len(indicator_list), ax_width=6, ax_height=4, sharex=True, sharey=True)

for ax, c in zip(
    axs.flatten(),
    indicator_list,
):

    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    artist = ax.scatter(
        'corr_complement',
        'strain_depth',
        data=strain_scores.sort_values(c),
        s=1,
        c=c,
        alpha=0.9,
        cmap='rainbow',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4),
    )
    cbar = fig.colorbar(artist, cax=cax)
    cbar.solids.set_alpha(1.0)
    ax.axhline(depth_threshold, lw=1, linestyle='--')  # strain depth-ratio threshold
    ax.axhline(1, xmin=0., xmax=0.5, lw=1, linestyle='--', color='k')  # depth ratio of 1
    ax.axvline(1 - corr_threshold, lw=1, linestyle='--')  # strain correlation threshold
    # ax.set_xscale('symlog', linthresh=1e-1)
    ax.set_title(c)
    # TODO: xscale logit?
    ax.set_yscale('symlog', linthresh=1e-2)
    ax.set_ylim(bottom=0)
# With sharex=True, inverting once flips every panel so correlation increases to the right.
ax.invert_xaxis()
# The plotted x values (and tick labels) are 1 - correlation, so label them as such.
ax.set_xlabel('1 - correlation')
ax.set_ylabel('depth ratio')

In [None]:
# Single-panel version of the indicator scatter: gene families plotted as (1 - strain correlation)
# vs strain depth ratio, coloured by tile_depth_hit, with a reference line at depth ratio 1.
indicator_list = [
        # 'bitscore_hit',
        # 'not_bitscore_hit',
        'tile_depth_hit',
        # 'not_tile_depth_hit',
        # 'species_gene',
        # 'species_gene_denovo',
        # 'species_gene_denovo2',
        # 'species_gene_reference',
        # 'gene_length',
        # 'tile_depth_q0',
        # 'tile_depth_q2',
    ]

fig, axs = lib.plot.subplots_grid(ncols=1, naxes=len(indicator_list), ax_width=6, ax_height=4, sharex=True, sharey=True)

for ax, c in zip(
    axs.flatten(),
    indicator_list,
):
    artist = ax.scatter(
        'corr_complement',
        'strain_depth',
        data=strain_scores.sort_values(c),
        s=1,
        c=c,
        alpha=0.9,
        cmap='copper',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4),
    )
    ax.axhline(1, lw=1, linestyle='--', color='k')  # depth ratio of 1
    # TODO: xscale logit?
    ax.set_yscale('symlog', linthresh=1e-2)
    ax.set_ylim(bottom=0)
ax.invert_xaxis()
# The plotted x values (and tick labels) are 1 - correlation, so label them as such.
ax.set_xlabel('1 - correlation')
ax.set_ylabel('depth ratio')

#### ^^^^^^

## Selected Gene Refinement

In [None]:
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
from sklearn.cluster import DBSCAN, AgglomerativeClustering, OPTICS

In [None]:
# Loose pre-filter: genes with at least weak correlation AND weak depth signal.
weak_hits = idxwhere(
    strain_scores.strain_corr.gt(0.1)
    & strain_scores.strain_depth.gt(0.1)
)

In [None]:
# Number of genes passing the loose (strain_corr > 0.1 and strain_depth > 0.1) pre-filter.
print(len(weak_hits))

In [None]:
# Choose the SPGC "threshold2" correlation cutoff. Rank the weak hits by correlation
# (descending) and find where the smoothed gap between consecutive values is largest.
bin_size = 500
x = (
    strain_scores.loc[weak_hits].strain_corr
    .sort_values(ascending=False)
    .to_frame()
    .assign(
        # Drop in correlation from the previous (higher-ranked) gene; NaN for the first.
        # `-diff()` equals previous minus current and also works on empty input.
        delta=lambda df: -df.strain_corr.diff(),
        # Centered moving average of that drop over `bin_size` genes.
        rolling_delta=lambda df: df.delta.rolling(bin_size, center=True).mean(),
    )
)

fig, ax = plt.subplots()
ax.scatter('strain_corr', 'rolling_delta', data=x)
# Log scale goes on the only axis in this figure (no twin axis is created).
ax.set_yscale('log')
ax.set_xlabel('correlation')
ax.set_ylabel(f'mean correlation gap (window={bin_size})')

# Ignore the first 2 * bin_size ranks (highest correlations) when searching for the gap.
# NOTE(review): rationale for the 2x offset is not visible here — confirm.
thresh_idx = x.rolling_delta.iloc[bin_size * 2:].idxmax()
spgc2_corr_thresh = x.loc[thresh_idx].strain_corr
n_above_thresh = (x.strain_corr > spgc2_corr_thresh).sum()
print(n_above_thresh, spgc2_corr_thresh)

In [None]:
# Hard-coded depth-ratio cutoff for SPGC "threshold2"; used together with spgc2_corr_thresh.
spgc2_depth_thresh = 0.25

In [None]:
# One panel per candidate per-gene indicator, each colored by that indicator.
indicator_list = [
    'bitscore_hit', 'not_bitscore_hit',
    'tile_depth_hit', 'not_tile_depth_hit',
    'species_gene', 'species_gene_denovo', 'species_gene_denovo2', 'species_gene_reference',
    'gene_length',
    'tile_depth_q0',
]

fig, axs = lib.plot.subplots_grid(
    ncols=2, naxes=len(indicator_list), ax_width=6, ax_height=4, sharex=True, sharey=True
)

for ax, c in zip(axs.flatten(), indicator_list):
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    # Sorted by the indicator, so the highest values are drawn last (on top).
    artist = ax.scatter(
        'corr_complement', 'strain_depth',
        data=strain_scores.sort_values(c),
        c=c, s=1, alpha=0.9,
        cmap='rainbow', norm=mpl.colors.SymLogNorm(linthresh=1e-4),
    )
    cbar = fig.colorbar(artist, cax=cax)
    cbar.solids.set_alpha(1.0)  # opaque colorbar even though points use alpha=0.9
    # Cutoffs: SPGC2 depth, depth ratio 1 (left half of the panel, axes coords), SPGC2 correlation.
    ax.axhline(spgc2_depth_thresh, lw=1, linestyle='--')
    ax.axhline(1, xmin=0., xmax=0.5, lw=1, linestyle='--', color='k')
    ax.axvline(1 - spgc2_corr_thresh, lw=1, linestyle='--')
    ax.set_xscale('symlog', linthresh=1e-3)
    ax.set_yscale('symlog', linthresh=1e-2)
    ax.set_ylim(bottom=0)
    ax.set_title(c)

ax.invert_xaxis()
ax.set_xlabel('correlation')
ax.set_ylabel('depth ratio')

### ^^^^^^

### Experiment: How to pick cutoffs (Laplace Smoothing)

#### 2D Thresholding

### Exploration: What is a "reference gene hit"? (tile depth)

In [None]:
# For genes above the tile-depth threshold: count, per column of the transposed table,
# how many of those genes exceed bitscore_threshold, then tabulate those counts.
# NOTE(review): assumes unstack() puts ORFs on rows and reference genes on columns — confirm.
(
    orf_x_midas.unstack().astype(float)
    .gt(bitscore_threshold)
    .T
    .reindex(idxwhere(strain_scores.tile_depth_q0 > tile_depth_threshold), fill_value=0)
    .sum()
    .value_counts()
    .sort_index()
)

In [None]:
# Histogram edges: 0, then 100 log-spaced edges from 1e-2 to 1e1.
bins = [0, *np.logspace(-2, 1, num=100)]
plt.hist(strain_scores.loc[strain_scores.species_gene_reference, 'strain_depth'].sort_values(), bins=bins)
plt.yscale('log')
plt.xscale('symlog', linthresh=1e-4)

# Lowest depth ratios among reference-annotated species genes.
print(strain_scores.loc[strain_scores.species_gene_reference, 'strain_depth'].sort_values().head())

In [None]:
# NOTE(review): not used in this cell; the loop below sweeps thresh_list instead.
bitscore_thresh = 0.95

# Per gene: reference tile depth (q0) and best ORF bitscore ratio. The two Series are
# aligned on their indexes, so a gene present in only one of them gets NaN in the other.
d = pd.DataFrame({
    'q0': reference_strain_tile_depth_q0.to_series().reindex(strain_corr.index, fill_value=0),
    'bitscore_ratio': orf_x_midas.unstack(fill_value=0).max(),
})

thresh_list = [0.99, 0.95, 0.8, 0.5, 0.3, 0.1]
bins = np.logspace(-2, 3, num=50)  # q0 == 0 falls outside these bins and is not counted

fig, axs = lib.plot.subplots_grid(
    ncols=2, naxes=len(thresh_list), ax_width=5, ax_height=3.5, sharex=True, sharey=True
)

for thresh, ax in zip(thresh_list, axs.flatten()):
    # Two explicit comparisons (not a negated mask): NaN bitscore ratios fail both
    # and therefore appear in neither histogram.
    ax.hist('q0', data=d[d.bitscore_ratio < thresh], bins=bins, alpha=0.5, label='not-matched')
    ax.hist('q0', data=d[d.bitscore_ratio >= thresh], bins=bins, alpha=0.5, label='matched')
    ax.set_title(f'bitscore_ratio >= {thresh}')

# Set on the last panel; presumably shared with the others via sharex=True.
ax.set_xscale('log')
ax.set_xticks(np.logspace(-3, 3, num=7))
axs[0, 0].legend()

In [None]:
# Distribution of per-row counts of entries with bitscore ratio > 0.95.
orf_x_midas.unstack().astype(float).gt(0.95).sum(axis=1).value_counts().sort_index()

In [None]:
# The five lowest-depth genes among those with bitscore ratio > 0.95.
strain_scores.loc[
    strain_scores.bitscore_ratio > 0.95,
    ['bitscore_ratio', 'strain_corr', 'strain_depth', 'species_gene_reference', 'tile_depth_q0', 'gene_length'],
].sort_values('strain_depth').head(5)

In [None]:
# bitscore_ratio_ratio: divide each row by its row max, then take the per-column max.
d = strain_scores.assign(
    bitscore_ratio_ratio=(
        orf_x_midas.unstack().fillna(0)
        .pipe(lambda m: m.divide(m.max(1), axis=0))
        .max()
    )
)

fig, ax = plt.subplots(figsize=(10, 5))
# NOTE(review): bitscore_ratio_ratio is computed above but the points are colored by
# 'bitscore_ratio' — confirm which column was intended.
fig.colorbar(
    ax.scatter('corr_complement', 'tile_depth_q0', data=d, s=1, c='bitscore_ratio', cmap='viridis_r'),
    ax=ax,
)
ax.set_yscale('symlog', linthresh=1e-0)
ax.set_xscale('log')
ax.invert_xaxis()

In [None]:
d = strain_scores

# One figure on each side of the 0.95 bitscore-ratio cutoff; color is correlation on a fixed 0-1 scale.
for _subset in (d[d.bitscore_ratio > 0.95], d[d.bitscore_ratio <= 0.95]):
    fig, ax = plt.subplots()
    fig.colorbar(
        ax.scatter(
            'tile_depth_q0', 'strain_depth', c='strain_corr', data=_subset, s=1,
            norm=mpl.colors.PowerNorm(1, vmin=0, vmax=1),
        ),
        ax=ax,
    )
    ax.set_yscale('symlog', linthresh=1e-2)
    ax.set_xscale('symlog', linthresh=1e-2)

In [None]:
# Display the top inferred strain (defined in an earlier cell, not shown here).
top_inferred_strain

In [None]:
# First five rows of the reference-strain accuracy table, sorted by F1 (descending).
pd.read_table(path['reference_strain_accuracy'], index_col=0).sort_values('f1', ascending=False).head()

### Gene Length Bias

In [None]:
# Log depth ratio vs log centroid length, for genes with a bitscore hit and nonzero depth.
d = (
    strain_scores[strain_scores.bitscore_hit & (strain_scores.strain_depth > 0)]
    .join(gene_cluster.centroid_99_length)
    .assign(
        # NOTE(review): "/ 3" presumably converts nucleotide length to codons — confirm.
        log_centroid_99_length=lambda df: np.log10(df.centroid_99_length / 3),
        log_strain_depth=lambda df: np.log10(df.strain_depth),
    )
)

sns.regplot(
    x='log_centroid_99_length',
    y='log_strain_depth',
    data=d,
    lowess=True,
    scatter_kws=dict(s=2),
)
plt.axhline(0, lw=1, linestyle='--', color='k')  # log10(depth ratio) == 0
plt.ylim(-1.5, 1.5)

In [None]:
# Same gene-length view, restricted to reference-annotated species genes.
d = (
    strain_scores[strain_scores.bitscore_hit & (strain_scores.strain_depth > 0)]
    .join(gene_cluster.centroid_99_length)
    .assign(
        # NOTE(review): "/ 3" presumably converts nucleotide length to codons — confirm.
        log_centroid_99_length=lambda df: np.log10(df.centroid_99_length / 3),
        log_strain_depth=lambda df: np.log10(df.strain_depth),
    )
)

sns.regplot(
    x='log_centroid_99_length',
    y='log_strain_depth',
    data=d[d.species_gene_reference],
    lowess=True,
    scatter_kws=dict(s=2),
)
plt.axhline(0, lw=1, linestyle='--', color='k')  # log10(depth ratio) == 0
plt.ylim(-1.5, 1.5)

### Comparison to Reference Strains

In [None]:
# Load the reference and strain-genome metagenotypes (GT-Pro per the path names) for this species.
ref_geno = sf.Metagenotype.load(path['gtpro_reference_genotype'])
# NOTE(review): hard-coded path rather than a path_patterns entry — consider adding one.
strain_geno = sf.Metagenotype.load(f"data/species/sp-{species}/strain_genomes.gtpro.mgtp.nc")
# ref_hits = (xr.load_dataarray(path['reference_copy_number']) >= 1).to_series().unstack('gene_id').T

In [None]:
# Combine both metagenotypes along the 'sample' dimension (keys 'ref' and 'strain')
# and report the resulting dimension sizes.
m = sf.Metagenotype.concat(dict(ref=ref_geno, strain=strain_geno), dim='sample')
print(m.sizes)

In [None]:
# Display the strain genome id in use (defined in an earlier cell, not shown here).
strain_genome_id

In [None]:
# Plot the combined metagenotype restricted to `position_ss` (a position subset defined elsewhere).
sf.plot_metagenotype(m.sel(position=position_ss))

## Benchmark Accuracy Against Other Tools

### Ground-truth

In [None]:
# Column names of the eggNOG-mapper annotation table (comment lines are skipped on read).
eggnog_names = [
    "query", "seed_ortholog", "evalue", "score", "eggNOG_OGs", "max_annot_lvl",
    "COG_category", "Description", "Preferred_name", "GOs", "EC", "KEGG_ko",
    "KEGG_Pathway", "KEGG_Module", "KEGG_Reaction", "KEGG_rclass", "BRITE",
    "KEGG_TC", "CAZy", "BiGG_Reaction", "PFAMs",
]
_path = f'data/species/sp-{species}/genome/{strain_genome_id}.prodigal-single.cds.emapper.d/proteins.emapper.annotations'
print(_path)
# '-' placeholders become NaN so downstream dropna/str methods treat them as missing.
genome_eggnog = (
    pd.read_table(_path, comment="#", names=eggnog_names, index_col="query")
    .rename_axis(index="gene_id")
    .replace({'-': np.nan})
)
genome_eggnog.info()

In [None]:
# One row per (ORF, KEGG KO) pair; named 'ko' to match uhgg_x_ko used for the comparisons.
orf_x_ko = genome_eggnog.KEGG_ko.dropna().str.split(',').explode().rename('ko')

In [None]:
# One row per (ORF, COG) pair: split eggNOG OGs, strip the '@taxon' suffix, keep COG ids.
orf_x_cog = (
    genome_eggnog.eggNOG_OGs
    .str.split(',')
    .explode()
    .str.split("@").str[0]
    # na=False: rows whose eggNOG_OGs was '-' (mapped to NaN on load) would otherwise
    # put NaN into the boolean mask, and indexing with it raises ValueError.
    [lambda x: x.str.startswith('COG', na=False)]
    .rename('cog')
)

In [None]:
# One row per (ORF, eggNOG OG) pair, with the '@taxon' suffix removed.
orf_x_eggnog = (
    genome_eggnog.eggNOG_OGs
    .str.split(',')
    .explode()
    .str.split("@")
    .str.get(0)
    .rename('eggnog')
)

In [None]:
# Distinct KOs on the strain genome, with the number of ORF annotations for each.
orf_unique_ko_hit = orf_x_ko.value_counts()
print(*map(len, (orf_x_ko, orf_unique_ko_hit)))
orf_unique_ko_hit

In [None]:
# Distinct COGs on the strain genome, with the number of ORF annotations for each.
orf_unique_cog_hit = orf_x_cog.value_counts()
print(*map(len, (orf_x_cog, orf_unique_cog_hit)))
orf_unique_cog_hit

### MIDAS UHGG Metadata

In [None]:
# Column names of the eggNOG-mapper annotation table (comment lines are skipped on read).
eggnog_names = [
    "query", "seed_ortholog", "evalue", "score", "eggNOG_OGs", "max_annot_lvl",
    "COG_category", "Description", "Preferred_name", "GOs", "EC", "KEGG_ko",
    "KEGG_Pathway", "KEGG_Module", "KEGG_Reaction", "KEGG_rclass", "BRITE",
    "KEGG_TC", "CAZy", "BiGG_Reaction", "PFAMs",
]
_path = f'data/species/sp-{species}/pangenome.centroids.emapper.d/proteins.emapper.annotations'
print(_path)
# '-' placeholders become NaN so downstream dropna/str methods treat them as missing.
uhgg_eggnog = (
    pd.read_table(_path, comment="#", names=eggnog_names, index_col="query")
    .rename_axis(index="gene_id")
    .replace({'-': np.nan})
)
uhgg_eggnog.info()

In [None]:
# Pangenome centroid -> annotation tables, one row per (gene, annotation) pair.
uhgg_x_ko = uhgg_eggnog.KEGG_ko.dropna().str.split(',').explode().rename('ko')
# na=False: NaN eggNOG_OGs ('-' on load) would otherwise make the boolean mask raise.
uhgg_x_cog = uhgg_eggnog.eggNOG_OGs.str.split(',').explode().str.split("@").str[0][lambda x: x.str.startswith('COG', na=False)].rename('cog')
uhgg_x_eggnog = uhgg_eggnog.eggNOG_OGs.str.split(',').explode().str.split("@").str[0].rename('eggnog')

### SPGC

In [None]:
# SPGC calls (correlation AND depth cutoffs) as an index-only frame, annotated with KOs.
spgc_hits = strain_scores.loc[strain_scores.corr_and_depth_hit, []]
spgc_ko_hit = spgc_hits.join(uhgg_x_ko).dropna()
spgc_unique_ko_hit = spgc_ko_hit['ko'].value_counts()
print(*map(len, (spgc_hits, spgc_ko_hit, spgc_unique_ko_hit)))
spgc_unique_ko_hit

In [None]:
# COGs among SPGC-called genes (genes without a COG are dropped by dropna).
spgc_cog_hit = spgc_hits.join(uhgg_x_cog).dropna()
spgc_unique_cog_hit = spgc_cog_hit['cog'].value_counts()
print(*map(len, (spgc_hits, spgc_cog_hit, spgc_unique_cog_hit)))
spgc_unique_cog_hit

In [None]:
# eggNOG OGs among SPGC-called genes (genes without an OG are dropped by dropna).
spgc_eggnog_hit = spgc_hits.join(uhgg_x_eggnog).dropna()
spgc_unique_eggnog_hit = spgc_eggnog_hit['eggnog'].value_counts()
print(*map(len, (spgc_hits, spgc_eggnog_hit, spgc_unique_eggnog_hit)))
spgc_unique_eggnog_hit

### SPGC Depth Only

In [None]:
# Depth-only calls as an index-only frame, annotated with KOs.
depth_hits = strain_scores.loc[strain_scores.depth_hit, []]
depth_ko_hit = depth_hits.join(uhgg_x_ko).dropna()
depth_unique_ko_hit = depth_ko_hit['ko'].value_counts()
print(*map(len, (depth_hits, depth_ko_hit, depth_unique_ko_hit)))
depth_unique_ko_hit

In [None]:
# COGs among depth-only calls (genes without a COG are dropped by dropna).
depth_cog_hit = depth_hits.join(uhgg_x_cog).dropna()
depth_unique_cog_hit = depth_cog_hit['cog'].value_counts()
print(*map(len, (depth_hits, depth_cog_hit, depth_unique_cog_hit)))
depth_unique_cog_hit

### SPGC threshold2

In [None]:
# Genes passing the SPGC threshold2 cutoffs (correlation >= spgc2_corr_thresh,
# depth > spgc2_depth_thresh), as an index-only frame annotated with KOs.
spgc2_hits = strain_scores.loc[
    (strain_scores.strain_corr >= spgc2_corr_thresh) & (strain_scores.strain_depth > spgc2_depth_thresh),
    [],
]
spgc2_ko_hit = spgc2_hits.join(uhgg_x_ko).dropna()
spgc2_unique_ko_hit = spgc2_ko_hit['ko'].value_counts()
print(*map(len, (spgc2_hits, spgc2_ko_hit, spgc2_unique_ko_hit)))
spgc2_unique_ko_hit

In [None]:
# COGs among SPGC threshold2 calls (genes without a COG are dropped by dropna).
spgc2_cog_hit = spgc2_hits.join(uhgg_x_cog).dropna()
spgc2_unique_cog_hit = spgc2_cog_hit['cog'].value_counts()
print(*map(len, (spgc2_hits, spgc2_cog_hit, spgc2_unique_cog_hit)))
spgc2_unique_cog_hit

### SPGC threshold3

### PanPhlan on SPGC Depths

In [None]:
# Among samples assigned to the top inferred strain, pick the xjin_ sample with the
# highest species depth; PanPhlan / StrainPanDA results are read for this sample below.
focal_xjin_sample = (
    pd.read_table(path["strain_samples"])
    .loc[lambda df: df.strain == top_inferred_strain, ['sample']]
    .join(species_depth[species].rename('depth'), on='sample')
    .sort_values('depth', ascending=False)
    .loc[lambda df: df['sample'].str.startswith('xjin_'), 'sample']
    .iloc[0]
)

In [None]:
_path = f'data/group/xjin/species/sp-{species}/r.proc.pangenomes{path_params["gene_params"]}.panphlan_hit.tsv'
print(_path)
# PanPhlan gene-by-sample presence table, indexed by PanPhlan gene id.
_panphlan2_hit = (
    pd.read_table(_path)
    .rename(columns={'Unnamed: 0': 'panphlan_gene_id'})
    .set_index('panphlan_gene_id')
)
# Fall back to the first sample column when the focal sample is absent.
if focal_xjin_sample in _panphlan2_hit:
    _focal_xjin_sample = focal_xjin_sample
else:
    print("FOCAL SAMPLE NOT FOUND")
    _focal_xjin_sample = _panphlan2_hit.columns[0]
# Genes called present (value == 1) in that sample, as an index-only frame.
panphlan2_hit = _panphlan2_hit.loc[_panphlan2_hit[_focal_xjin_sample] == 1, []]

panphlan2_ko_hit = panphlan2_hit.join(uhgg_x_ko).dropna()
panphlan2_unique_ko_hit = panphlan2_ko_hit['ko'].value_counts()
print(*map(len, (panphlan2_hit, panphlan2_ko_hit, panphlan2_unique_ko_hit)))
panphlan2_unique_ko_hit

In [None]:
# COGs among PanPhlan-called genes (genes without a COG are dropped by dropna).
panphlan2_cog_hit = panphlan2_hit.join(uhgg_x_cog).dropna()
panphlan2_unique_cog_hit = panphlan2_cog_hit['cog'].value_counts()
print(*map(len, (panphlan2_hit, panphlan2_cog_hit, panphlan2_unique_cog_hit)))
panphlan2_unique_cog_hit

### StrainPanDA on SPGC Depths

In [None]:
_path = f'data/group/xjin/species/sp-{species}/r.proc.pangenomes{path_params["gene_params"]}.spanda-s2.strain_sample.csv'
print(_path)

# StrainPanDA strain-by-sample table (variable name kept for compatibility).
_panphlan2_strain_sample = pd.read_csv(_path)
# Set the sample column explicitly in BOTH branches; otherwise `_focal_xjin_sample`
# would keep a stale value from an earlier cell when the focal sample is present.
if focal_xjin_sample in _panphlan2_strain_sample:
    _focal_xjin_sample = focal_xjin_sample
else:
    print("FOCAL SAMPLE NOT FOUND")
    _focal_xjin_sample = _panphlan2_strain_sample.columns[0]

# Row label with the largest value in the chosen sample column
# (presumably the dominant StrainPanDA strain — confirm the CSV's row labels).
spanda2_strain_name = _panphlan2_strain_sample[_focal_xjin_sample].idxmax()
print(spanda2_strain_name)

In [None]:
# StrainPanDA gene-family x strain table; keep the rows present in the selected strain.
_path = f'data/group/xjin/species/sp-{species}/r.proc.pangenomes{path_params["gene_params"]}.spanda-s2.genefamily_strain.csv'
print(_path)
# NOTE(review): this hard-coded value replaces the spanda2_strain_name computed with
# idxmax in the preceding StrainPanDA cell — confirm which one is intended.
spanda2_strain_name = 'strain2'  # This could be found by looking for the dominant strain in f'data/group/xjin_102395_subset/species/sp-{species}/r.proc.panphlan.spanda-s2.strain_sample.csv'
# Rows whose value in the strain column is truthy, as an index-only frame
# (index presumably gene-family ids — confirm against the CSV layout).
spanda2_hit = pd.read_csv(_path).astype(bool)[spanda2_strain_name][lambda x: x].to_frame()[[]]

In [None]:
# KOs among StrainPanDA-called genes (genes without a KO are dropped by dropna).
spanda2_ko_hit = spanda2_hit.join(uhgg_x_ko).dropna()
spanda2_unique_ko_hit = spanda2_ko_hit['ko'].value_counts()
print(*map(len, (spanda2_hit, spanda2_ko_hit, spanda2_unique_ko_hit)))
spanda2_unique_ko_hit

In [None]:
# COGs among StrainPanDA-called genes (genes without a COG are dropped by dropna).
spanda2_cog_hit = spanda2_hit.join(uhgg_x_cog).dropna()
spanda2_unique_cog_hit = spanda2_cog_hit['cog'].value_counts()
print(*map(len, (spanda2_hit, spanda2_cog_hit, spanda2_unique_cog_hit)))
spanda2_unique_cog_hit

### !!!! Compare SPGC, Depth-Only, PanPhlan, and StrainPanDA to Ground Truth

In [None]:
# Precision / recall / F1 of each method's distinct KOs against the KOs on the strain genome.
ground_truth = set(orf_unique_ko_hit.index)

for name, hit in dict(
    spgc=spgc_unique_ko_hit,
    spgc2=spgc2_unique_ko_hit,
    depth=depth_unique_ko_hit,
    panphlan2=panphlan2_unique_ko_hit,
    spanda2=spanda2_unique_ko_hit,
).items():
    hit = set(hit.index)
    n_true_pos = len(hit & ground_truth)
    # Guarded ratios: an empty hit set or empty ground truth would raise ZeroDivisionError.
    precision = n_true_pos / len(hit) if hit else np.nan
    recall = n_true_pos / len(ground_truth) if ground_truth else np.nan
    # F1 (harmonic mean of precision and recall) == 2 * TP / (|hit| + |truth|).
    n_total = len(hit) + len(ground_truth)
    f1 = 2 * n_true_pos / n_total if n_total else np.nan
    print(name, len(ground_truth - hit), n_true_pos, len(hit - ground_truth))
    print(f"{name}, {precision:.1%}, {recall:.1%}, {f1:.1%}")

In [None]:
# Precision / recall / F1 of each method's distinct COGs against the COGs on the strain genome.
ground_truth = set(orf_unique_cog_hit.index)

for name, hit in dict(
    spgc=spgc_unique_cog_hit,
    spgc2=spgc2_unique_cog_hit,
    depth=depth_unique_cog_hit,
    panphlan2=panphlan2_unique_cog_hit,
    spanda2=spanda2_unique_cog_hit,
).items():
    hit = set(hit.index)
    n_true_pos = len(hit & ground_truth)
    # Guarded ratios: an empty hit set or empty ground truth would raise ZeroDivisionError.
    precision = n_true_pos / len(hit) if hit else np.nan
    recall = n_true_pos / len(ground_truth) if ground_truth else np.nan
    # F1 (harmonic mean of precision and recall) == 2 * TP / (|hit| + |truth|).
    n_total = len(hit) + len(ground_truth)
    f1 = 2 * n_true_pos / n_total if n_total else np.nan
    print(name, len(ground_truth - hit), n_true_pos, len(hit - ground_truth))
    print(f"{name}, {precision:.1%}, {recall:.1%}, {f1:.1%}")

In [None]:
# Gene-level precision / recall / F1 of each method's calls against genes with a
# reference tile-depth hit.
ground_truth = set(idxwhere(strain_scores.tile_depth_hit))

for name, hit in dict(
    spgc=spgc_hits,
    spgc2=spgc2_hits,
    depth=depth_hits,
    panphlan2=panphlan2_hit,
    spanda2=spanda2_hit,
).items():
    hit = set(hit.index)
    n_true_pos = len(hit & ground_truth)
    # Guarded ratios: an empty hit set or empty ground truth would raise ZeroDivisionError.
    precision = n_true_pos / len(hit) if hit else np.nan
    recall = n_true_pos / len(ground_truth) if ground_truth else np.nan
    # F1 (harmonic mean of precision and recall) == 2 * TP / (|hit| + |truth|).
    n_total = len(hit) + len(ground_truth)
    f1 = 2 * n_true_pos / n_total if n_total else np.nan
    print(name, len(ground_truth - hit), n_true_pos, len(hit - ground_truth))
    print(f"{name}, {precision:.1%}, {recall:.1%}, {f1:.1%}")

#### ^^^^^^

## Survey Across xjin Species

In [None]:
# Survey: for each species with exactly one reference strain genome, collect the
# accuracy metrics of its top inferred strain plus sample-count and depth summaries.
ref_strains = pd.read_table('meta/genome.tsv', index_col='genome_id')[lambda x: ~x.genome_path.isna()]

species_strain_counts = ref_strains.value_counts('species_id')

_all_species_depth = species_depth


all_species_strain_accuracy = {}
for _species in tqdm(species_strain_counts.index):
    _strain = idxwhere(ref_strains.species_id == _species)[0]

    # Format paths to reflect _species.
    _path_params = path_params.copy()
    _path_params.update(species=_species, strain_genome_id=_strain)
    _path = {k: path_patterns[k].format(**_path_params) for k in path_patterns}

    # Skip species with more than one reference strain.
    _count = species_strain_counts[_species]
    if _count > 1:
        print(f"Species {_species} has {_count} strains.")
        continue

    # Some species are missing from _all_species_depth. Use an empty Series so the
    # .reindex(..., fill_value=0) below yields zero depth in every sample, the same
    # treatment samples missing from the depth table already get (a bare np.nan has
    # no .reindex and would raise AttributeError).
    if _species in _all_species_depth:
        _species_depth = _all_species_depth[_species]
    else:
        _species_depth = pd.Series(dtype=float)

    if not os.path.exists(_path['flag']):
        print(f"{_species} is missing flag file")
        continue
    if not os.path.exists(_path["fit"]):
        print(f"{_species} is missing fit file")
        continue
    if not os.path.exists(_path['reference_strain_accuracy']):
        print(f"{_species} is missing accuracy file")
        continue
    _species_gene_list = pd.read_table(_path["species_gene_reference"], names=['gene_id']).gene_id.tolist()
    _fit = sf.World.load(_path["fit"])
    _thresh = pd.read_table(_path["strain_thresholds"], index_col='strain')
    _accuracy = pd.read_table(_path['reference_strain_accuracy'], index_col='strain')
    _accuracy_depth_only = pd.read_table(_path['reference_strain_accuracy_depth_only'], index_col='strain')
    # Strain with the highest mean `community` value across xjin_ samples.
    _top_strain = _fit.community.sel(sample=idxwhere(_fit.community.sample.to_series().str.startswith('xjin_'))).mean("sample").to_series().idxmax()
    if _top_strain not in _accuracy.index:
        print(f"{_species} {_strain} is missing accuracy info")
        continue
    with open(_path["species_free_samples"]) as f:
        _species_free_samples = [line.strip() for line in f]
    _sample_to_strain = pd.read_table(_path["strain_samples"], index_col=['sample']).strain
    _sample_list = idxwhere(_sample_to_strain == _top_strain)
    # Species depth in the top strain's samples (0 where a sample has no depth entry).
    _strain_sample_depth = _species_depth.reindex(_sample_list, fill_value=0)

    _accuracy = _accuracy.join(_accuracy_depth_only, rsuffix='_depth_only').loc[_top_strain]
    _accuracy['geno_entropy'] = _fit.genotype.entropy().to_series()[_top_strain]
    _accuracy['num_strain_samples'] = len(_sample_list)
    _accuracy['num_species_free_samples'] = len(_species_free_samples)
    _accuracy['num_hmp_samples'] = sum([not s.startswith('xjin_') for s in _sample_list])
    _accuracy['num_species_genes'] = len(_species_gene_list)
    _accuracy['depth_stdev'] = _strain_sample_depth.std()
    _accuracy['depth_max'] = _strain_sample_depth.max()
    _accuracy['depth_sum'] = _strain_sample_depth.sum()
    try:
        _accuracy['corr_thresh'] = _thresh.loc[_top_strain].correlation
        _accuracy['depth_thresh'] = _thresh.loc[_top_strain].depth_low
    except KeyError as err:  # FIXME: This should never happen (top strain missing from thresholds).
        _accuracy['corr_thresh'] = np.nan
        _accuracy['depth_thresh'] = np.nan
        print(err)
    all_species_strain_accuracy[_species] = _accuracy

all_species_strain_accuracy = (
    pd.DataFrame(all_species_strain_accuracy).T
    .assign(
        power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_strain_samples)).fillna(0)
    )
)


