### Preamble

#### Project Template

In [None]:
%load_ext autoreload

In [None]:
# Work from the directory named by the PROJECT_ROOT environment variable, then show the resolved location.
import os as _os

_os.chdir(_os.environ['PROJECT_ROOT'])
_os.path.realpath(_os.curdir)

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable


In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

#### Set Style

In [None]:
# Plot styling: seaborn "talk" context and a figure DPI of 50.
sns.set_context('talk')
plt.rcParams.update({'figure.dpi': 50})

## Set Parameters / Load Data

In [None]:
# Registries extended in later cells: format-string patterns and the parameters substituted into them.
path_patterns, path_params = {}, {}

In [None]:
# Add new parameters.
path_params.update(dict(
    group_subset='xjin',
    group='xjin_hmp2',
    stemA='r.proc',
))

# Add new patterns.
path_patterns.update(dict(
    species_taxonomy="ref/gtpro/species_taxonomy_ext.tsv",
    all_species_depth_subset="data/group/{group_subset}/{stemA}.gtpro.species_depth.tsv",
    all_species_depth="data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    midasdb_genomes="ref/uhgg_genomes_all_4644.tsv",
    strain_genomes="meta/genome.tsv",
))

# This part is generic and should be run after every new batch of path_patterns and path_params is added.
# Every pattern is formatted with the current parameters; the cell then fails fast,
# listing each resolved path that does not exist on disk.
path = {k: path_patterns[k].format(**path_params) for k in path_patterns}
_path_exists = {}
for p in path:
    _path_exists[path[p]] = os.path.exists(path[p])
# The failure message must read `_path_exists`; the previously used name `path_exists`
# is undefined, so a missing file raised NameError instead of reporting which files are missing.
assert all(_path_exists.values()), '\n'.join(["Missing files:"] + [p for p in _path_exists if not _path_exists[p]])

### Species Abundance

In [None]:
# Load species depth tables for the full group and for the subset group.
species_depth = lib.thisproject.data.load_species_depth(path['all_species_depth'])
species_depth_subset = lib.thisproject.data.load_species_depth(path['all_species_depth_subset'])

# Relative abundance: each row divided by its own row total.
# Vectorized `div` replaces the row-wise `apply(lambda ...)`, which runs a Python call per row.
rabund = species_depth.div(species_depth.sum(axis=1), axis=0)
rabund_subset = species_depth_subset.div(species_depth_subset.sum(axis=1), axis=0)

n_species = 40
# Rank species by how many subset rows exceed 1e-5 relative abundance and keep the first n_species.
top_species = (
    (rabund_subset > 1e-5)
    .sum()
    .sort_values(ascending=False)
    .head(n_species)
    .index
)

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-7, 0, num=51)

# One transparent, frameless histogram row per species: subset first, full group drawn on top.
for species_id, ax in zip(top_species, axs):
    for _table in (rabund_subset, rabund):
        ax.hist(_table[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    prevalence = (rabund_subset[species_id] > 1e-5).mean()
    ax.set_title("")
    # Hide both axes, the background patch and every spine so that overlapping rows do not mask each other.
    for _axis in (ax.yaxis, ax.xaxis):
        _axis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f'{species_id} ({prevalence:0.0%})',
        xy=(0.05, 0.1),
        ha='left',
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-7)
    ax.set_ylim(top=300)

# `ax` is still the last (bottom) row here: give it back its x-axis.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the rows overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
n_species = 40
# Single detection threshold on species depth, used both to rank species and to report prevalence.
# The annotation previously used 1e-5 (carried over from the relative-abundance cell), so the
# percentage shown did not correspond to the 1e-3 threshold that determines the ordering.
presence_thresh = 1e-3
top_species = (species_depth_subset > presence_thresh).sum().sort_values(ascending=False).head(n_species).index

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-3, 4, num=51)

# One transparent, frameless histogram row per species: subset first, full group drawn on top.
for species_id, ax in zip(top_species, axs):
    ax.hist(species_depth_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(species_depth[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    # Fraction of subset rows in which this species exceeds the detection threshold.
    prevalence = (species_depth_subset[species_id] > presence_thresh).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-4)
    ax.set_ylim(top=300)

# `ax` is still the last (bottom) row here: give it back its x-axis.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the rows overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
n_species = 40
# Same ranking as for the top species, but keep ranks n_species+1 .. 2*n_species.
second_species = (
    (rabund_subset > 1e-5)
    .sum()
    .sort_values(ascending=False)
    .head(n_species * 2)
    .tail(n_species)
    .index
)

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-7, 0, num=51)

# One transparent, frameless histogram row per species: subset first, full group drawn on top.
for species_id, ax in zip(second_species, axs):
    for _table in (rabund_subset, rabund):
        ax.hist(_table[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    prevalence = (rabund_subset[species_id] > 1e-5).mean()
    ax.set_title("")
    # Hide both axes, the background patch and every spine so that overlapping rows do not mask each other.
    for _axis in (ax.yaxis, ax.xaxis):
        _axis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f'{species_id} ({prevalence:0.0%})',
        xy=(0.05, 0.1),
        ha='left',
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-7)
    ax.set_ylim(top=300)

# `ax` is still the last (bottom) row here: give it back its x-axis.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the rows overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
n_species = 40
# Single detection threshold on species depth, used both to rank species and to report prevalence.
# The annotation previously used 1e-5 (carried over from the relative-abundance cell), so the
# percentage shown did not correspond to the 1e-3 threshold that determines the ordering.
presence_thresh = 1e-3
# Keep ranks n_species+1 .. 2*n_species of the detection-count ranking.
second_species = (species_depth_subset > presence_thresh).sum().sort_values(ascending=False).head(n_species * 2).tail(n_species).index

fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

bins = np.logspace(-3, 4, num=51)

# One transparent, frameless histogram row per species: subset first, full group drawn on top.
for species_id, ax in zip(second_species, axs):
    ax.hist(species_depth_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(species_depth[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    # Fraction of subset rows in which this species exceeds the detection threshold.
    prevalence = (species_depth_subset[species_id] > presence_thresh).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-4)
    ax.set_ylim(top=300)

# `ax` is still the last (bottom) row here: give it back its x-axis.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes the rows overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
# Clustered heatmap of the subset depth table: cosine metric, power-law (gamma = 1/5) color scaling.
sns.clustermap(
    data=species_depth_subset,
    metric='cosine',
    norm=mpl.colors.PowerNorm(gamma=1/5),
)

In [None]:
# Square matrix of pairwise cosine *distances* between the columns of species_depth_subset
# (despite the "corr" name, the values are distances; the next cell plots 1 - distance).
_cosine_dist_condensed = sp.spatial.distance.pdist(species_depth_subset.T, metric='cosine')
species_depth_corr = pd.DataFrame(
    sp.spatial.distance.squareform(_cosine_dist_condensed),
    index=species_depth_subset.columns,
    columns=species_depth_subset.columns,
)

In [None]:
# Clustered heatmap of one minus the pairwise matrix computed above.
sns.clustermap(
    data=1 - species_depth_corr,
    figsize=(20, 20),
)

#### Set Focal Species (Manual)

In [None]:
# Focal species for the rest of the notebook, set by hand as a string ID.
species = '100760'

# Load the taxonomy table and show the focal species' entry.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    path["species_taxonomy"]
)
species_taxonomy.loc[species]

### Set Ground-truth Reference Strain

In [None]:
# Genome metadata table (tab-separated), with every column read as a string.
strain_genome = pd.read_csv(path["strain_genomes"], sep='\t', dtype='str')
# Show the rows that belong to the focal species.
strain_genome.loc[strain_genome.species_id == species]

In [None]:
# Genome IDs listed for the focal species.
strain_genome_ids = strain_genome[strain_genome.species_id == species].genome_id
print(strain_genome_ids)
# Validate before indexing: with zero matches, `[0]` would raise an uninformative IndexError
# before the uniqueness check could report the problem.
assert strain_genome_ids.shape[0] == 1, (
    f"Expected exactly one genome for species {species}, found {strain_genome_ids.shape[0]}."
)
strain_genome_id = strain_genome_ids.tolist()[0]

### Set Gene Family Clustering-level

In [None]:
# Gene-family clustering level; substituted into path patterns as {centroid} and used to
# select the `centroid_{centroid}` column of the cluster table below.
centroid = 75

In [None]:
# Register the focal-species parameters.
path_params.update(
    centroid=centroid,
    species=species,
    strain_genome_id=strain_genome_id,
)

# Register the reference-strain patterns.
path_patterns.update(
    strain_cds_length='data/species/sp-{species}/genome/{strain_genome_id}.prodigal-single.cds.nlength.tsv',
    strain_x_uhgg_bitscore_ratio='data/species/sp-{species}/genome/{strain_genome_id}.midas_uhgg_pangenome-blastn.bitscore_ratio-c{centroid}.tsv',
)

# Generic step: rerun after every new batch of path_patterns and path_params is added.
# Format each pattern with the current parameters, record which resolved paths exist,
# and stop with a list of the missing ones if any are absent.
path = {k: pattern.format(**path_params) for k, pattern in path_patterns.items()}
_path_exists = {resolved: os.path.exists(resolved) for resolved in path.values()}
_missing = [resolved for resolved, found in _path_exists.items() if not found]
assert all(_path_exists.values()), '\n'.join(["Missing files:"] + _missing)

### Set other SPGC parameters and check paths

In [None]:
# Register the remaining parameters that are substituted into the patterns below.
path_params.update(
    stemB='filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0',
    # stemC = 'sfacts42-seed0',
    gene_params=f"99-v22-agg{centroid}",
    # thresh_params = 'thresh-corrq10-depth300',
    corr_thresh="350",
    depth_thresh="250",
    specgene_params='specgene-ref-t25-p95',
    ss_params="all",
    # ss_params="deepest-n10",
)

# Add new patterns.
# Register the path patterns for the per-species inputs used in the rest of the notebook.
path_patterns.update(dict(
    flag="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_thresh-corr{corr_thresh}-depth{depth_thresh}.strain_files.flag",
    fit="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.world.nc",
    # refit="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.world.nc",
    strain_correlation="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}.strain_correlation.tsv",
    strain_depth_ratio="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}.strain_depth_ratio.tsv",
    strain_fraction="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.comm.tsv",
    species_gene_mean_depth="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_{specgene_params}.species_depth.tsv",
    species_gtpro_depth="data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    species_correlation="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-denovo2-t30-n500.species_correlation.tsv",
    species_gene="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_{specgene_params}.species_gene.list",
    species_gene_denovo="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-denovo-n500.species_gene.list",
    species_gene_denovo2="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-denovo2-t30-n500.species_gene.list",
    species_gene_reference="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-ref-t25-p95.species_gene.list",
    species_free_samples="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.spgc_specgene-ref-t25-p95.species_free_samples.list",
    # Uses {stemA} like every sibling pattern; the prefix was previously hard-coded as "r.proc",
    # which would silently point at the wrong file if stemA were ever changed.
    strain_samples="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.spgc_ss-{ss_params}.strain_samples.tsv",
    strain_thresholds="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_thresh-corr{corr_thresh}-depth{depth_thresh}.strain_gene_threshold.tsv",
    gene_annotations="ref/midasdb_uhgg_gene_annotations/sp-{species}.gene{centroid}_annotations.tsv",
    # raw_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.pangenome95.gene{centroid}_depth.nc",
    # norm_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.gene99-mapq0-agg{centroid}.normed_depth2.nc",
    raw_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.depth2.nc",
    # raw_gene_depth="data/group/{group}/species/sp-{species}/{stemA}.gene{centroid}.normed_depth2.nc",
    reference_copy_number="ref/midasdb_uhgg_pangenomes/{species}/gene{centroid}.reference_copy_number.nc",
    cluster_info="ref/midasdb_uhgg/pangenomes/{species}/cluster_info.txt",
    gtpro_reference_genotype="data/species/sp-{species}/gtpro_ref.mgtp.nc",
    reference_strain_accuracy="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_thresh-corr{corr_thresh}-depth{depth_thresh}.{strain_genome_id}.gene_content_reconstruction_accuracy.tsv",
    reference_strain_accuracy_depth_only="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-{ss_params}_thresh-corr0-depth{depth_thresh}.{strain_genome_id}.gene_content_reconstruction_accuracy.tsv",
    # reference_strain_mapping_q0="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq0-agg{centroid}.depth2.nc",
    reference_strain_mapping_q0="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene{gene_params}.depth2.nc",
    # reference_strain_mapping_q1="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq1-agg{centroid}.depth2.nc",
    # reference_strain_mapping_q2="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq2-agg{centroid}.depth2.nc",
    # reference_strain_mapping_q4="data/group/{group}/species/sp-{species}/ALL_STRAINS.tiles-l100-o99.gene99-mapq4-agg{centroid}.depth2.nc",
    xjin_benchmarking="data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_{specgene_params}_ss-xjin-{ss_params}_thresh-corr{corr_thresh}-depth{depth_thresh}.xjin_strain_summary.tsv",
))

# Generic step: rerun after every new batch of path_patterns and path_params is added.
# Format each pattern with the current parameters, record which resolved paths exist,
# and stop with a list of the missing ones if any are absent.
path = {k: pattern.format(**path_params) for k, pattern in path_patterns.items()}
_path_exists = {resolved: os.path.exists(resolved) for resolved in path.values()}
_missing = [resolved for resolved, found in _path_exists.items() if not found]
assert all(_path_exists.values()), '\n'.join(["Missing files:"] + _missing)

In [None]:
# Show the resolved 'species_free_samples' path.
path['species_free_samples']

In [None]:
# Show the resolved 'flag' path.
path['flag']

### SFacts results

In [None]:
# Load the fitted sfacts world and report its dimension sizes.
fit = sf.World.load(path['fit'])
print(fit.sizes)
# Seed NumPy's global RNG before drawing at most 1000 positions.
np.random.seed(0)
position_ss = fit.random_sample(position=min(fit.sizes['position'], 1000)).position

# Restrict the fit to samples that are also present in the subset depth table.
# Iterating over fit.sample.values keeps the fit's own sample order; the previous
# list(set(...) & set(...)) produced an arbitrary order that could change between kernel sessions.
_subset_samples = set(species_depth_subset.index)
fit_subset = fit.sel(sample=[s for s in fit.sample.values if s in _subset_samples])

# fuzzy_geno = sf.Genotype.load(path['fit'])  # FIXME: refit
# fuzzy_geno = sf.World.from_combined(fuzzy_geno, fit.metagenotype, fit.community)

# Display the first element returned by the metagenotype error evaluation.
sf.evaluation.metagenotype_error2(fit)[0]

#### Plotting

#### Identify focal sfacts strain

In [None]:
# Mean community value of each strain across the subset samples.
_mean_community = fit_subset.community.mean("sample")
print(_mean_community.to_series().sort_values(ascending=False).head(5))

# The focal strain is the one with the largest mean; it must exceed 0.9 on average.
top_inferred_strain = _mean_community.to_series().idxmax()

assert _mean_community.sel(strain=top_inferred_strain) > 0.9

In [None]:
# Read and sort the accuracy table once (it was previously read and sorted twice).
_reference_strain_accuracy = (
    pd.read_table(path['reference_strain_accuracy'], index_col=0)
    .sort_values('f1', ascending=False)
)
# Scores of the focal strain, followed by the five rows with the highest F1.
print(_reference_strain_accuracy.loc[top_inferred_strain, ['precision', 'recall', 'f1']])
_reference_strain_accuracy.head(5)

### Gene Annotations

In [None]:
# Cluster table indexed by its 'centroid_99' column (kept as a column too); the index is named 'gene_id'.
gene_cluster = (
    pd.read_table(path["cluster_info"])
    .set_index('centroid_99', drop=False)
    .rename_axis(index='gene_id')
)

# Header-less annotation table indexed by locus tag, with lower-cased column names.
_annotation_columns = ['locus_tag', 'ftype', 'length_bp', 'gene', 'EC_number', 'COG', 'product']
gene_annotation = pd.read_table(
    path["gene_annotations"],
    names=_annotation_columns,
    index_col='locus_tag',
).rename(columns=str.lower)

# One row per unique centroid at the chosen clustering level, joined with its annotation.
_centroid_ids = gene_cluster[f'centroid_{centroid}'].unique()
gene_meta = gene_cluster.loc[_centroid_ids].join(gene_annotation)

### Reference Strain Gene Matching

In [None]:
# The twelve column names applied to header-less BLAST tabular output in this notebook.
blastp_header_names = (
    'qseqid sseqid pident length mismatch gapopen '
    'qstart qend sstart send evalue bitscore'
).split()

In [None]:
# Header-less two-column table (orf, length), squeezed down from a one-column frame indexed by ORF.
orf_length = pd.read_table(
    path['strain_cds_length'],
    names=['orf', 'length'],
    index_col=['orf'],
).squeeze()

# Bitscore-ratio table indexed by (orf, gene) pairs, squeezed the same way.
orf_x_midas = pd.read_table(
    path['strain_x_uhgg_bitscore_ratio'],
    index_col=['orf', 'gene'],
).squeeze()
# orf_x_midas = pd.read_table('data/species/sp-101380/genome/Ruminococcus-gnavus-ATCC-29149_MinIONHybrid.midas_uhgg_pangenome-blastp.bitscore_ratio-c75.tsv', index_col=['orf', 'gene']).squeeze()


# _strain_x_strain = (
#     pd.read_table(
#         path['strain_x_strain'],
#         names=blastp_header_names
#     )
# )

# _max_bitscore = _strain_x_strain.groupby(['qseqid']).bitscore.max()

# strain_x_uhgg = (
#     pd.read_table(
#         path['strain_x_uhgg'],
#         names=blastp_header_names
#     )
#     .assign(bitscore_ratio=lambda x: x.bitscore / x.qseqid.map(_max_bitscore))
#     .assign(sseq_centroid=lambda x: x.sseqid.map(gene_cluster[f'centroid_{centroid}']))
# )

# best_uhgg_hit = strain_x_uhgg.groupby('qseqid').apply(lambda d: d.sort_values('bitscore').iloc[-1]).groupby('sseq_centroid').bitscore_ratio.max()

In [None]:
# orf_x_midas = strain_x_uhgg.groupby(['qseqid', 'sseq_centroid']).bitscore_ratio.max()


bins = np.linspace(0, 1)
# ORF x gene matrix of bitscore ratios, with absent pairs filled with 0.
_bitscore_matrix = orf_x_midas.unstack(fill_value=0)
# Best ratio per gene (max over ORFs), then best ratio per ORF (max over genes).
plt.hist(_bitscore_matrix.max(axis=0), bins=bins, density=True)
plt.hist(_bitscore_matrix.max(axis=1), bins=bins, density=True, alpha=0.5)
plt.yscale('log')
None

In [None]:
# For each ORF, count the genes whose bitscore ratio exceeds 0.95, then tabulate how many ORFs have each count.
_n_hits_per_orf = (orf_x_midas.unstack().astype(float) > 0.95).sum(axis=1)
_n_hits_per_orf.value_counts().sort_index()

### Strain-specific correlations/depth

In [None]:
# Gene x strain correlation matrix; (gene, strain) pairs absent from the file are filled with 0.
strain_corr = (
    pd.read_table(path["strain_correlation"], index_col=['gene_id', 'strain'])
    .squeeze()
    .unstack('strain', fill_value=0)
)
# Gene x strain depth-ratio matrix; no fill value is given here, so absent pairs stay missing.
strain_depth = (
    pd.read_table(path["strain_depth_ratio"], index_col=['gene_id', 'strain'])
    .squeeze()
    .unstack()
)
# strain_corr, strain_depth = align_indexes(*align_indexes(strain_corr, strain_depth), axis="columns")

### Strain Metadata

In [None]:
# Per-strain threshold table, with file column names mapped to the names used in this notebook.
_threshold_renames = {
    # 'correlation_strict': 'corr_threshold_strict',
    'correlation': 'corr_threshold',
    # 'correlation_lenient': 'corr_threshold_lenient',
    'depth_high': 'depth_thresh_high',
    'depth_low': 'depth_thresh_low',
}
strain_thresholds = pd.read_table(path["strain_thresholds"], index_col='strain').rename(columns=_threshold_renames)

# Add the entropy of each strain's fitted genotype as a 'genotype_entropy' column.
_genotype_entropy = fit.genotype.entropy().to_series().rename('genotype_entropy')
_strain_meta = strain_thresholds.join(_genotype_entropy)
# Additional columns, currently disabled:
    # .join(refit.genotype.entropy().to_series().rename('genotype_refit_entropy'))
    # .join(fit.metagenotype.entropy().to_series().rename('metagenotype_entropy').groupby(sample_to_strain).mean().rename(int))
    # .join(strain_to_sample_list.apply(len).rename('num_samples'))
    # .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).std().rename('depth_stdev').rename(int))
    # .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).max().rename('depth_max').rename(int))
    # .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).sum().rename('depth_sum').rename(int))
    # .assign(power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_samples)).fillna(0))
strain_meta = _strain_meta


# power_index_thresh = 5
# genotype_entropy_thresh = 0.2
# genotype_refit_entropy_thresh = 1.0

# high_power_strain_list = idxwhere(
#     (strain_meta.power_index > power_index_thresh)
#     & (strain_meta.genotype_entropy < genotype_entropy_thresh)
#     & (strain_meta.genotype_refit_entropy < genotype_refit_entropy_thresh)
# )
# print(len(high_power_strain_list))
# highest_power_strain_list = strain_meta.sort_values('power_index', ascending=False).head(3).index

# plt.scatter(strain_meta.power_index, strain_meta.corr_threshold, c=strain_meta.genotype_refit_entropy, alpha=0.5)
# plt.axvline(power_index_thresh, lw=1, linestyle='--', color='k')
# plt.colorbar()
# plt.xscale('log')

# Display the full strain metadata table; the trailing comment is a disabled row selection for the focal strain.
strain_meta#.loc[[top_inferred_strain]]

In [None]:
# Header-less two-column table (sample, correlation), squeezed down from a one-column frame.
species_corr = pd.read_table(
    path["species_correlation"],
    names=['sample', 'correlation'],
    index_col='sample',
).squeeze()

### SPGC Species Genes

In [None]:
# Each list file holds one entry per line; surrounding whitespace is stripped from every line.
with open(path["species_gene"]) as f:
    species_gene_hit = list(map(str.strip, f))

with open(path["species_gene_denovo"]) as f:
    species_gene_denovo_hit = list(map(str.strip, f))

# with open("data/group/xjin_hmp2/species/sp-100203/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.gene99-v22-agg75.spgc.species_gene2-n500.list") as f:
with open(path["species_gene_denovo2"]) as f:
    species_gene_denovo_hit2 = list(map(str.strip, f))

with open(path["species_gene_reference"]) as f:
    species_gene_reference_hit = list(map(str.strip, f))

### Strain gene definition

In [None]:
# Select the reference strain's entry from the tile-mapping depth array.
_reference_sample = f'sp-{species}/genome/{strain_genome_id}'
reference_strain_tile_depth_q0 = xr.load_dataarray(path['reference_strain_mapping_q0']).sel(sample=_reference_sample)
# reference_strain_tile_depth_q1 = xr.load_dataarray(path['reference_strain_mapping_q1']).sel(sample=f'sp-{species}/genome/{strain_genome_id}')
# reference_strain_tile_depth_q2 = xr.load_dataarray(path['reference_strain_mapping_q2']).sel(sample=f'sp-{species}/genome/{strain_genome_id}')
# reference_strain_tile_depth_q4 = xr.load_dataarray(path['reference_strain_mapping_q4']).sel(sample=f'sp-{species}/genome/{strain_genome_id}')

bins = np.logspace(-1, 3, num=100)
# A pseudocount of 0.1 keeps zero-depth genes inside the first log-spaced bin.
_pseudocount = 1e-1
# All genes first, then only the species genes (those absent from the array count as 0).
_species_gene_tile_depth = reference_strain_tile_depth_q0.reindex(gene_id=species_gene_hit, fill_value=0)
plt.hist(reference_strain_tile_depth_q0 + _pseudocount, bins=bins, alpha=0.5)
plt.hist(_species_gene_tile_depth + _pseudocount, bins=bins, alpha=0.5)
# plt.hist(reference_strain_tile_depth_q1, bins=np.logspace(-3, 3, num=50), alpha=0.5)
# plt.hist(reference_strain_tile_depth_q2, bins=np.logspace(-3, 3, num=50), alpha=0.5)
# plt.hist(reference_strain_tile_depth_q4, bins=np.logspace(-3, 3, num=50), alpha=0.5)
# plt.yscale('log')
plt.xscale('log')
plt.yscale('log')
# Reference lines at depths 200 (solid) and 50 (dashed).
plt.axvline(200, lw=1, linestyle='-', color='k')
plt.axvline(50, lw=1, linestyle='--', color='k')

### Per-gene Strain Corr/Depth

In [None]:
# Build a per-gene score table for one strain: raw scores, threshold-derived boolean flags,
# and gene-list memberships, sorted by bitscore ratio.
_strain = top_inferred_strain

# Strain-specific thresholds are taken from the strain metadata table.
depth_threshold = strain_meta.depth_thresh_low.loc[_strain]
corr_threshold = strain_meta.corr_threshold.loc[_strain]
# corr_threshold = 0.95  # Set manually, but this could/should be the automatically selected threshold.
# depth_threshold = 0.2  # Set manually, but this could/should be the automatically selected threshold.
# Fixed cutoffs used for the bitscore-ratio and tile-depth flags below.
bitscore_threshold = 0.95
tile_depth_threshold = 30

strain_scores = (
    pd.DataFrame(dict(
        # Per gene: the maximum bitscore ratio over all ORFs (absent ORF/gene pairs count as 0).
        bitscore_ratio=orf_x_midas.unstack(fill_value=0).max(),
        strain_corr=strain_corr[_strain],
        strain_depth=strain_depth[_strain],
        species_corr=species_corr,
        # Tile depth reindexed onto the rows of strain_corr; genes without a value get 0.
        tile_depth_q0=reference_strain_tile_depth_q0.to_series().reindex(strain_corr.index, fill_value=0),
        # tile_depth_q1=reference_strain_tile_depth_q1.to_series().reindex(strain_corr.index, fill_value=0),
        # tile_depth_q2=reference_strain_tile_depth_q2.to_series().reindex(strain_corr.index, fill_value=0),
        # tile_depth_q4=reference_strain_tile_depth_q4.to_series().reindex(strain_corr.index, fill_value=0),
        # strain_corr_q=strain_corr_q[_strain],
        # strain_depth_q=strain_depth_q[_strain],
    ))
    # Missing values in any score column become 0 before the flags are derived.
    .fillna(0)
    .assign(
        # Complementary pairs of flags around each fixed cutoff.
        bitscore_hit=lambda x: x.bitscore_ratio >= bitscore_threshold,
        not_bitscore_hit=lambda x: x.bitscore_ratio < bitscore_threshold,
        tile_depth_hit=lambda x: x.tile_depth_q0 >= tile_depth_threshold,
        not_tile_depth_hit=lambda x: x.tile_depth_q0 < tile_depth_threshold,
        # Flags based on the strain-specific thresholds (strict inequalities).
        depth_hit=lambda x: (x.strain_depth > depth_threshold),
        corr_and_depth_hit=lambda x: (x.strain_corr > corr_threshold) & (x.strain_depth > depth_threshold),
        # Membership of each row label in the four gene lists read from file.
        species_gene=lambda x: x.index.to_series().isin(species_gene_hit),
        species_gene_denovo=lambda x: x.index.to_series().isin(species_gene_denovo_hit),
        species_gene_denovo2=lambda x: x.index.to_series().isin(species_gene_denovo_hit2),
        species_gene_reference=lambda x: x.index.to_series().isin(species_gene_reference_hit),
        corr_complement=lambda x: 1 - x.strain_corr,
        # log_tile_depth=lambda x: np.log10(x.tile_depth + 1e-4),
        dummy=False,
        # Mean 'centroid_99_length' within each centroid group at the chosen clustering level
        # (assigned after fillna, so rows without a matching group are not filled here).
        gene_length=gene_cluster.groupby(f'centroid_{centroid}').centroid_99_length.mean(),
    )
    .sort_values('bitscore_ratio')
)

In [None]:
# Columns of strain_scores used, one per panel, to color the same scatter plot.
indicator_list = [
    'bitscore_hit', 'not_bitscore_hit',
    'tile_depth_hit', 'not_tile_depth_hit',
    'species_gene', 'species_gene_denovo',
    'species_gene_denovo2', 'species_gene_reference',
    'gene_length', 'tile_depth_q0',
    # 'tile_depth_q2',
]

fig, axs = lib.plot.subplots_grid(
    ncols=2,
    naxes=len(indicator_list),
    ax_width=6,
    ax_height=4,
    sharex=True,
    sharey=True,
)

for ax, c in zip(axs.flatten(), indicator_list):
    # A narrow colorbar axis is attached to the right of each panel.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    # Rows are sorted by the indicator, so rows with larger values are drawn last (on top).
    _scatter_kws = dict(
        s=1,
        c=c,
        alpha=0.9,
        cmap='rainbow',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4),
    )
    artist = ax.scatter('corr_complement', 'strain_depth', data=strain_scores.sort_values(c), **_scatter_kws)
    cbar = fig.colorbar(artist, cax=cax)
    cbar.solids.set_alpha(1.0)
    # Dashed guides: the strain's depth threshold, a depth ratio of 1 across the left half
    # of the panel, and the strain's correlation threshold expressed as 1 - correlation.
    ax.axhline(depth_threshold, lw=1, linestyle='--')
    ax.axhline(1, xmin=0., xmax=0.5, lw=1, linestyle='--', color='k')
    ax.axvline(1 - corr_threshold, lw=1, linestyle='--')
    ax.set_title(c)
    # TODO: xscale logit?
    ax.set_xscale('symlog', linthresh=1e-3)
    ax.set_yscale('symlog', linthresh=1e-2)
    ax.set_ylim(bottom=0)

# `ax` is the last panel drawn; the calls below are made on it.
ax.invert_xaxis()
ax.set_xlabel('correlation')
ax.set_ylabel('depth ratio')

In [None]:
from sklearn.neighbors import KernelDensity

# Two-dimensional kernel density estimates over (strain_corr, strain_depth), one per class.
_kde_features = ['strain_corr', 'strain_depth']
kde_refs = KernelDensity().fit(strain_scores.loc[lambda x: x.species_gene_reference][_kde_features].values)
# Fit the second density on the complement only. It was previously fit on *all* genes, which
# contradicts its name and the two-class posterior below, where it is weighted by (1 - prior)
# as the density of genes that are not reference species genes.
kde_not_refs = KernelDensity().fit(strain_scores.loc[lambda x: ~x.species_gene_reference][_kde_features].values)

# Evaluation grid over the unit square: xx and yy have shape (res_y, res_x).
res_x = 50
res_y = 50
xx, yy = np.meshgrid(np.linspace(0, 1, num=res_x), np.linspace(0, 1, num=res_y))
mesh_flat = np.stack((xx, yy), axis=2).reshape((-1,2))
loglik_given_refs = kde_refs.score_samples(mesh_flat)
loglik_given_not_refs = kde_not_refs.score_samples(mesh_flat)

# Density of each class on the grid (top: reference genes, bottom: the rest).
fig, axs = plt.subplots(2, figsize=(5, 10))
for ax, _loglik in zip(axs, [loglik_given_refs, loglik_given_not_refs]):
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    # Reshape to the grid's own shape (res_y, res_x); the former (res_x, res_y) was only
    # correct while both resolutions happened to be equal.
    artist = ax.pcolormesh(xx[0,:], yy[:,0], np.exp(_loglik.reshape(xx.shape)), vmin=0, vmax=0.2)
    cbar = fig.colorbar(artist, cax=cax)


def posteriorA(points, priorA, kdeA, kdeB):
    """Posterior probability of class A at each point under a two-class mixture.

    Computes ``priorA * pA / (priorA * pA + (1 - priorA) * pB)`` where ``pA``
    and ``pB`` are the densities given by ``kdeA`` and ``kdeB``.

    Parameters
    ----------
    points : array-like
        Points passed unchanged to ``score_samples`` of both estimators.
    priorA : float
        Prior probability of class A; class B gets ``1 - priorA``.
    kdeA, kdeB : objects with a ``score_samples(points)`` method
        Must return the log-density at each point.

    Returns
    -------
    numpy.ndarray
        One posterior per point. A point where both weighted densities are
        exactly zero still yields NaN, as the ratio is undefined there.
    """
    # Work in log space: exponentiating each log-density separately underflows
    # to 0 far from the data and turned the ratio into 0/0 = NaN even when the
    # relative likelihood of the two classes was perfectly well defined.
    with np.errstate(divide='ignore'):  # log(0) for a prior of exactly 0 or 1
        log_numerator = np.log(priorA) + kdeA.score_samples(points)
        log_alternative = np.log1p(-priorA) + kdeB.score_samples(points)
    log_denominator = np.logaddexp(log_numerator, log_alternative)
    return np.exp(log_numerator - log_denominator)


# Map the posterior probability of the reference class over the evaluation
# grid for two different prior weights on that class.
fig, axs = plt.subplots(2, figsize=(5, 10))

for ax, _prior_refs in zip(axs, [0.5, 0.9]):
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    _posterior = posteriorA(mesh_flat, _prior_refs, kde_refs, kde_not_refs)
    # `xx` has shape (res_y, res_x); reshaping to `xx.shape` keeps rows as y and
    # columns as x even if the two grid resolutions are ever made different.
    artist = ax.pcolormesh(xx[0,:], yy[:,0], _posterior.reshape(xx.shape))
    cbar = fig.colorbar(artist, cax=cax)
    ax.set_title(f'posterior (prior={_prior_refs})')
    ax.set_xlabel('strain_corr')
    ax.set_ylabel('strain_depth')

In [None]:
# Grid search over (correlation-complement, depth) cutoffs. For every indicator
# set, count (a) rows falling below both cutoffs and (b) indicator rows falling
# above both, then pick the cutoff pair maximizing log(a + 1) + _a * log(b + 1).
corr_complement_thresh_list = np.logspace(-3, 0, num=50)
depth_thresh_list = np.logspace(-2, 0, num=50)
fig, axs = plt.subplots(3, 3, figsize=(20, 15), sharex=True, sharey=True)

# Loop-invariant inputs, extracted once instead of once per grid point.
_scores = strain_scores[['strain_corr', 'strain_depth']]
_grid_shape = (len(corr_complement_thresh_list), len(depth_thresh_list))

for (indicator_name, indicator_hit), ax_col in zip(
    dict(
        tile_depth_hit=idxwhere(reference_strain_tile_depth_q0.to_series() > 30),
        species_gene_reference_hit=species_gene_reference_hit,
        species_gene_hit=species_gene_hit,
    ).items(),
    axs.T,
):
    ax_col[0].set_title(indicator_name)
    # Scores restricted to this indicator's rows. Labels missing from
    # `strain_scores` become NaN and therefore fail every comparison below,
    # which matches excluding them from the count.
    _hit_scores = _scores.reindex(indicator_hit)
    both_exclude = np.empty(_grid_shape)
    indicator_include = np.empty(_grid_shape)
    for (i, corr_complement_t), (j, depth_t) in product(
        enumerate(corr_complement_thresh_list),
        enumerate(depth_thresh_list)
    ):
        _corr_t = 1 - corr_complement_t
        both_exclude[i, j] = (
            (_scores.strain_corr < _corr_t) & (_scores.strain_depth < depth_t)
        ).sum()
        indicator_include[i, j] = (
            (_hit_scores.strain_corr > _corr_t) & (_hit_scores.strain_depth > depth_t)
        ).sum()
    both_exclude = (
        pd.DataFrame(both_exclude, index=corr_complement_thresh_list, columns=depth_thresh_list)
        .rename_axis(index='corr_complement_thresh', columns='depth_thresh')
    )
    indicator_include = (
        pd.DataFrame(indicator_include, index=corr_complement_thresh_list, columns=depth_thresh_list)
        .rename_axis(index='corr_complement_thresh', columns='depth_thresh')
    )

    # `_a` weights the "include" count relative to the "exclude" count.
    _a = 2
    both_exclude_trnsf = np.log(both_exclude + 1)
    indicator_include_trnsf = np.log(indicator_include + 1)
    _corr_complement_thresh, _depth_thresh = (both_exclude_trnsf + _a * indicator_include_trnsf).stack().idxmax()
    _corr_thresh = 1 - _corr_complement_thresh

    # Rows of the figure: exclude score, include score, combined objective.
    for d, ax in zip([both_exclude_trnsf, indicator_include_trnsf, both_exclude_trnsf + _a * indicator_include_trnsf], ax_col):
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.25)
        artist = ax.pcolormesh(d.index, d.columns, d.T, norm=mpl.colors.PowerNorm(3))
        cbar = fig.colorbar(artist, cax=cax)
        # Mark the selected optimum on every panel of this column.
        ax.scatter(_corr_complement_thresh, _depth_thresh, color='r')
    print(f"{indicator_name}: corr_thresh={_corr_thresh}, depth_thresh={_depth_thresh}")

# NOTE: after the loop `_corr_complement_thresh`, `_corr_thresh` and
# `_depth_thresh` keep the values found for the LAST indicator only.
axs[0, 0].set_xscale('log')
axs[0, 0].invert_xaxis()
axs[0, 0].set_yscale('log')
axs[0, 0].set_ylabel('exclude')
axs[1, 0].set_ylabel('include')
axs[2, 0].set_ylabel('product')

fig.tight_layout()

In [None]:
# Columns of `strain_scores` used, one per panel, to color the scatter below.
indicator_list = [
        'bitscore_hit',
        'not_bitscore_hit',
        'tile_depth_hit',
        'not_tile_depth_hit',
        'species_gene_denovo2',
        'species_gene_reference',
        'tile_depth_q0',
    ]

fig, axs = lib.plot.subplots_grid(ncols=2, naxes=len(indicator_list), ax_width=6, ax_height=4, sharex=True, sharey=True)

for ax, c in zip(
    axs.flatten(),
    indicator_list,
):
    
    # Reserve a narrow strip on the right of this panel for its colorbar.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)  
    artist = ax.scatter(
        'corr_complement',
        'strain_depth',
        # Rows are passed in ascending order of the indicator column.
        data=strain_scores.sort_values(c),
        s=1,
        c=c,
        alpha=0.9,
        cmap='rainbow',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4),
    )
    cbar = fig.colorbar(artist, cax=cax)
    # Draw the colorbar fully opaque even though the points use alpha=0.9.
    cbar.solids.set_alpha(1.0)
    # Black dashed y=1 guide across the left half of the axes.
    ax.axhline(1, xmin=0., xmax=0.5, lw=1, linestyle='--', color='k')

    # Solid lines: `_depth_thresh` / `_corr_complement_thresh`, which are not
    # defined in this cell (the threshold grid-search cell assigns them, leaving
    # the values of its final indicator). Dashed lines: `depth_threshold` and
    # `corr_threshold`, also defined outside this cell.
    ax.axhline(_depth_thresh, lw=1, linestyle='-')
    ax.axvline(_corr_complement_thresh, lw=1, linestyle='-')
    ax.axhline(depth_threshold, lw=1, linestyle='--')
    ax.axvline(1 - corr_threshold, lw=1, linestyle='--')
    
    ax.set_title(c)
    # TODO: xscale logit?
    ax.set_yscale('symlog', linthresh=1e-2)
    # ax.set_xscale('logit')
    ax.set_xscale('symlog', linthresh=1e-2)
    ax.set_ylim(bottom=0)
# These calls act on the last `ax` from the loop; the grid was requested with
# sharex/sharey=True, so presumably the x inversion applies to every panel.
# NOTE(review): the x data is `corr_complement` while the label says
# 'correlation' -- confirm the label is intended for the inverted axis.
ax.invert_xaxis()
ax.set_xlabel('correlation')
ax.set_ylabel('depth ratio')

### Experiment: How to pick cutoffs (Laplace Smoothing)

In [None]:
# Fit an asymmetric Laplace distribution to log10(1 - strain_corr + 1e-5) over
# the rows flagged `species_gene_denovo2`, and compare it with a symmetric
# (kappa=1) Laplace sharing the same location and scale.
bins = np.linspace(-4, 0, num=100)

x = np.log10((1 - strain_scores[lambda x: x.species_gene_denovo2].strain_corr) + 1e-5)
# Replace -inf and exact zeros with the mean of the finite values.
# NOTE(review): an exact 0 here corresponds to strain_corr == 1e-5, not to a
# degenerate value -- confirm that replacing zeros is intended.
mean_finit_x = x[np.isfinite(x)].mean()
x = x.replace({-np.inf: mean_finit_x, 0: mean_finit_x})
# x = sp.special.expit(strain_scores[lambda x: x.species_gene_denovo2].strain_corr)


kappa0, loc0, scale0 = sp.stats.laplace_asymmetric.fit(x)
# kappa1, loc1, scale1 = sp.stats.laplace_asymmetric.fit(x, f0=1)

# dist0 uses the fitted asymmetry; dist1 forces kappa=1 with the same loc/scale.
dist0 = sp.stats.laplace_asymmetric(kappa=kappa0, loc=loc0, scale=scale0)
dist1 = sp.stats.laplace_asymmetric(kappa=1, loc=loc0, scale=scale0)


# Histogram of the data, the 99th percentile of each distribution (solid for
# dist0, dashed for dist1), and both densities evaluated at the bin edges.
plt.hist(x, bins=bins, density=True)
plt.axvline(dist0.ppf(0.99), color='k')
plt.axvline(dist1.ppf(0.99), color='k', linestyle='--')
plt.plot(bins, dist0.pdf(bins))
plt.plot(bins, dist1.pdf(bins))

# Fitted parameters, then the 99th percentile mapped back with 1 - 10**q.
# NOTE(review): the back-transform does not undo the 1e-5 offset added above.
print(kappa0, loc0, scale0)
print(1 - 10 ** dist0.ppf(0.99))

# print(kappa1, loc1, scale1)
print(1 - 10 ** dist1.ppf(0.99))


# plt.yscale('symlog', linthresh=1e-2, linscale=0.1)
# plt.xscale('log')
# Bare None keeps the cell from echoing the last plotting return value.
None

In [None]:
# Ten largest strain correlations among the rows flagged `species_gene_denovo2`,
# shown in ascending order.
strain_scores.loc[strain_scores.species_gene_denovo2, 'strain_corr'].sort_values().tail(10)

In [None]:
# Histogram (200 bins) of strain_corr over the rows flagged `species_gene_denovo2`.
plt.hist(strain_scores[lambda x: x.species_gene_denovo2].strain_corr, bins=200)
# Bare None keeps the cell from echoing the histogram's return value.
None

In [None]:
# Probability plot of the two Laplace models against the data: `x` is
# log10(1 - strain_corr) for the `species_gene_denovo2` rows, ordered by
# descending correlation (hence ascending x).
x = np.log10(1 - strain_scores[lambda x: x.species_gene_denovo2].strain_corr.sort_values(ascending=False))
# Empirical quantile positions, sized from the `x` defined in THIS cell. This
# was previously computed before `x` was reassigned, so it silently depended on
# whatever `x` an earlier cell had left behind.
qq = np.linspace(0, 1, num=len(x))
# NOTE(review): `dist0`/`dist1` come from an earlier cell and were fit with a
# 1e-5 offset inside the log that is not applied here -- confirm intended.
plt.plot([0, 1], [0, 1], lw=1, c='k')
plt.scatter(qq, dist0.cdf(x), s=1)
plt.scatter(qq, dist1.cdf(x), s=1)

In [None]:
# Overlay the log10 strain-depth distributions of the `species_gene_denovo2`
# rows (x) and the `species_gene_reference` rows (y) on shared bins.
bins = np.linspace(-2, 1, num=100)

x = np.log10(strain_scores.loc[strain_scores.species_gene_denovo2, 'strain_depth'])
y = np.log10(strain_scores.loc[strain_scores.species_gene_reference, 'strain_depth'])
# The second histogram is semi-transparent so the first stays visible.
for _log_depth, _alpha in [(x, None), (y, 0.5)]:
    plt.hist(_log_depth, bins=bins, density=True, alpha=_alpha)
None

#### 2D Thresholding

### Exploration: What is a "reference gene hit"? (tile depth)

In [None]:
# Labels whose tile depth exceeds the threshold.
_tile_depth_labels = idxwhere(strain_scores.tile_depth_q0 > tile_depth_threshold)
# Boolean bitscore-hit table, restricted (missing labels filled with 0) to the
# labels above, summed per column, then tabulated as a distribution of sums.
(
    (orf_x_midas.unstack().astype(float) > bitscore_threshold)
    .T
    .reindex(_tile_depth_labels, fill_value=0)
    .sum()
    .value_counts()
    .sort_index()
)

In [None]:
# Depth distribution of the rows flagged `species_gene_reference`; the bins add
# an explicit zero edge below the log-spaced edges.
bins = [0] + list(np.logspace(-2, 1, num=100))
_reference_depth = strain_scores[lambda x: x.species_gene_reference].strain_depth.sort_values()
plt.hist(_reference_depth, bins=bins)
plt.xscale('symlog', linthresh=1e-4)
plt.yscale('log')

# Smallest depths among those rows.
print(_reference_depth.head())

In [None]:
# NOTE: `bitscore_thresh` is assigned here but not referenced again in this cell.
bitscore_thresh = 0.95

# Table indexed like `strain_corr`: `q0` is the reference tile depth (0 where
# absent) and `bitscore_ratio` is the column-wise maximum of the unstacked
# `orf_x_midas` table.
d = pd.DataFrame(dict(
    q0=reference_strain_tile_depth_q0.to_series().reindex(strain_corr.index, fill_value=0),
    # q1=reference_strain_tile_depth_q1.to_series().reindex(strain_corr.index, fill_value=0),
    # q2=reference_strain_tile_depth_q2.to_series().reindex(strain_corr.index, fill_value=0),
    # q4=reference_strain_tile_depth_q4.to_series().reindex(strain_corr.index, fill_value=0),
    bitscore_ratio=orf_x_midas.unstack(fill_value=0).max()
))

# Candidate bitscore-ratio cutoffs, one panel each.
thresh_list = [0.99, 0.95, 0.8, 0.5, 0.3, 0.1]

fig, axs = lib.plot.subplots_grid(ncols=2, naxes=len(thresh_list), ax_width=5, ax_height=3.5, sharex=True, sharey=True)
bins = np.logspace(-2, 3, num=50)

# Per cutoff, overlay the q0 histograms of rows below vs. at-or-above the cutoff.
for thresh, ax in zip(thresh_list, axs.flatten()):
# for q, ax in zip(['q2'], axs):
    ax.hist('q0', data=d[(d.bitscore_ratio < thresh)], bins=bins, alpha=0.5, label='not-matched')
    ax.hist('q0', data=d[(d.bitscore_ratio >= thresh)], bins=bins, alpha=0.5, label='matched')
    ax.set_title(f'bitscore_ratio >= {thresh}')
# plt.yscale('log')
# These calls act on the last `ax` from the loop; the grid was requested with
# sharex=True, so presumably the scale applies to every panel.
ax.set_xscale('log')
ax.set_xticks(np.logspace(-3, 3, num=7))
axs[0,0].legend()
# fig.tight_layout()

In [None]:
# Number of entries above 0.95 in each row of the unstacked bitscore table,
# tabulated as a distribution of per-row counts.
_above_cutoff = orf_x_midas.unstack().astype(float) > 0.95
_above_cutoff.sum(axis=1).value_counts().sort_index()

In [None]:
# The five lowest-depth rows among those with bitscore_ratio above 0.95.
_columns_shown = [
    'bitscore_ratio',
    'strain_corr',
    'strain_depth',
    'species_gene_reference',
    'tile_depth_q0',
    'gene_length',
]
(
    strain_scores
    .loc[strain_scores.bitscore_ratio > 0.95, _columns_shown]
    .sort_values('strain_depth')
    .head(5)
)

In [None]:
# Add `bitscore_ratio_ratio`: each row of the unstacked (NaN -> 0) bitscore
# table is divided by its own row maximum, then the column-wise maximum is taken.
d = strain_scores.assign(bitscore_ratio_ratio=orf_x_midas.unstack().fillna(0).divide(orf_x_midas.unstack().fillna(0).max(1), axis=0).max())

fig, ax = plt.subplots(figsize=(10, 5))
# NOTE(review): the points are colored by 'bitscore_ratio', not by the
# 'bitscore_ratio_ratio' column computed above -- confirm which one is intended.
plt.scatter('corr_complement', 'tile_depth_q0', data=d, s=1, c='bitscore_ratio', cmap='viridis_r')
plt.colorbar()
plt.yscale('symlog', linthresh=1e-0)
plt.xscale('log')
ax.invert_xaxis()

In [None]:
# Strain depth against tile depth, colored by strain correlation: one figure
# for rows with bitscore_ratio above 0.95, then one for rows at or below it.
d = strain_scores

for _row_mask in (d.bitscore_ratio > 0.95, d.bitscore_ratio <= 0.95):
    fig, ax = plt.subplots()
    plt.scatter(
        'tile_depth_q0',
        'strain_depth',
        c='strain_corr',
        data=d[_row_mask],
        s=1,
        # Linear color scale pinned to [0, 1] so both figures are comparable.
        norm=mpl.colors.PowerNorm(1, vmin=0, vmax=1),
    )
    plt.colorbar()
    plt.yscale('symlog', linthresh=1e-2)
    plt.xscale('symlog', linthresh=1e-2)

In [None]:
# Display `top_inferred_strain`, which is defined outside this cell.
top_inferred_strain

In [None]:
# Accuracy table read from disk, best F1 first (top five rows shown).
(
    pd.read_table(path['reference_strain_accuracy'], index_col=0)
    .sort_values('f1', ascending=False)
    .head()
)

### Gene Length Bias

In [None]:
# Relationship between log gene length and log strain depth for rows that are
# bitscore hits and have non-zero depth.
_hit_with_depth = strain_scores.bitscore_hit & (strain_scores.strain_depth > 0)
d = (
    strain_scores[_hit_with_depth]
    .join(gene_cluster.centroid_99_length)
    .assign(
        # Length is divided by 3 before the log (presumably nucleotides to
        # codons -- confirm).
        log_centroid_99_length=lambda x: np.log10(x.centroid_99_length / 3),
        log_strain_depth=lambda x: np.log10(x.strain_depth),
    )
)

# LOWESS-smoothed trend over the scatter.
sns.regplot(
    data=d,
    x='log_centroid_99_length',
    y='log_strain_depth',
    lowess=True,
    scatter_kws={'s': 2},
)
# Reference line at log depth 0, i.e. a depth ratio of 1.
plt.axhline(0, lw=1, linestyle='--', color='k')
plt.ylim(-1.5, 1.5)

In [None]:
# Same length-vs-depth relationship as computed for all bitscore hits, but the
# plot is restricted to the rows flagged `species_gene_reference`.
_hit_with_depth = strain_scores.bitscore_hit & (strain_scores.strain_depth > 0)
d = (
    strain_scores[_hit_with_depth]
    .join(gene_cluster.centroid_99_length)
    .assign(
        # Length is divided by 3 before the log (presumably nucleotides to
        # codons -- confirm).
        log_centroid_99_length=lambda x: np.log10(x.centroid_99_length / 3),
        log_strain_depth=lambda x: np.log10(x.strain_depth),
    )
)

# LOWESS-smoothed trend over the scatter of reference rows only.
sns.regplot(
    data=d[d.species_gene_reference],
    x='log_centroid_99_length',
    y='log_strain_depth',
    lowess=True,
    scatter_kws={'s': 2},
)
# Reference line at log depth 0, i.e. a depth ratio of 1.
plt.axhline(0, lw=1, linestyle='--', color='k')
plt.ylim(-1.5, 1.5)

### Comparison to Reference Strains

In [None]:
# Load two metagenotypes: the reference one from the configured path, and the
# strain-genome one from a path built directly from `species`.
# NOTE(review): the second path is hard-coded rather than taken from `path`.
ref_geno = sf.Metagenotype.load(path['gtpro_reference_genotype'])
strain_geno = sf.Metagenotype.load(f"data/species/sp-{species}/strain_genomes.gtpro.mgtp.nc")
# ref_hits = (xr.load_dataarray(path['reference_copy_number']) >= 1).to_series().unstack('gene_id').T

In [None]:
# Combine both metagenotypes along the 'sample' dimension under the keys
# 'ref' and 'strain', then report the resulting dimension sizes.
m = sf.Metagenotype.concat(dict(ref=ref_geno, strain=strain_geno), dim='sample')
print(m.sizes)

In [None]:
# Display `strain_genome_id`, which is defined outside this cell.
strain_genome_id

In [None]:
# Plot the combined metagenotype restricted to the positions in `position_ss`
# (defined outside this cell).
sf.plot_metagenotype(m.sel(position=position_ss))

In [None]:
# Pairwise distances between all samples of the combined metagenotype.
dmat = m.pdist()
_strain = f'strain_{strain_genome_id}'
_dist_to_strain = dmat[_strain].sort_values()
print(_dist_to_strain.head(10))

# Pick the closest *reference* sample explicitly. Taking position 1 of the
# sorted distances assumed the query strain itself always sorts first and that
# the runner-up is a reference; a zero-distance tie or a second 'strain_'
# sample breaks that assumption and yields a wrong or malformed genome id.
_ref_dist = _dist_to_strain[_dist_to_strain.index.str.startswith('ref_')]
_nearest_ref = _ref_dist.index[0]
assert _nearest_ref.startswith("ref_GUT_GENOME"), f"Unexpected reference sample name: {_nearest_ref}"
# Rewrite the sample name by swapping its 'ref_GUT_GENOME' prefix for 'UHGG'.
top_ref_strain = "UHGG" + _nearest_ref[len("ref_GUT_GENOME"):]

In [None]:
# Boolean series: entries of the reference copy-number array for
# `top_ref_strain` that are at least 1.
ref_hits = xr.open_dataarray(path['reference_copy_number']).sel(genome_id=top_ref_strain).to_series() >= 1

In [None]:
# Put the four hit indicators on a shared (outer-joined) index.
# (An alternative first indicator, `strain_scores.bitscore_ratio > 0.95`, was
# tried in place of `tile_depth_hit`.)
_strain_hits, _ref_hits, _inferred_hits, _depth_hits = align_indexes(
    strain_scores.tile_depth_hit,
    ref_hits,
    strain_scores.corr_and_depth_hit,
    strain_scores.depth_hit,
    how='outer',
)
d0 = pd.DataFrame({
    'strain': _strain_hits,
    'ref': _ref_hits,
    'inf': _inferred_hits,
    'depth': _depth_hits,
})
# Count each combination of the four indicators, then spread the 'strain'
# indicator across columns.
d1 = d0.value_counts().sort_index()
d1.unstack('strain', fill_value=0)

In [None]:
# Tile depth grouped by the 'strain' indicator and split (dodged) by the 'ref'
# indicator, both taken from `d0` joined onto `strain_scores`.
sns.stripplot(x='strain', y='tile_depth_q0', hue='ref', data=strain_scores.join(d0), dodge=True, s=1, alpha=0.2)
# plt.yscale('log')

In [None]:
# Distances from the strain genotypes to the reference genotypes; show the
# five smallest for the first strain row.
_strain_to_ref_dist = strain_geno.to_estimated_genotype().cdist(ref_geno.to_estimated_genotype())
_strain_to_ref_dist.iloc[0].sort_values().head()

In [None]:
# Samples in which `top_inferred_strain` has a community value above 0.95.
top_inferred_strain_sample_list = idxwhere(fit.community.data.sel(strain=top_inferred_strain).to_series() > 0.95)
# Positions ordered by that strain's genotype value.
allele_sorted_positions = fit.genotype.data.sel(strain=top_inferred_strain).to_series().sort_values().index

# Plot those samples over the ordered positions without re-clustering rows.
sf.plot_metagenotype(fit.metagenotype.sel(sample=top_inferred_strain_sample_list, position=allele_sorted_positions), row_cluster=False)

## Survey Across xjin Species

In [None]:
# Collect, for every species with exactly one reference strain, the accuracy of
# its top inferred strain together with depth and sample-count covariates.
ref_strains = pd.read_table('meta/genome.tsv', index_col='genome_id')[lambda x: ~x.genome_path.isna()]

species_strain_counts = ref_strains.value_counts('species_id')

_all_species_depth = species_depth


# Per-species records are gathered in their own dict so that
# `all_species_strain_accuracy` is only ever bound to the final DataFrame.
_accuracy_by_species = {}
for _species in tqdm(species_strain_counts.index):
    _strain = idxwhere(ref_strains.species_id == _species)[0]

    # Format paths to reflect _species.
    _path_params = path_params.copy()
    _path_params.update(species=_species, strain_genome_id=_strain)
    _path = {k: path_patterns[k].format(**_path_params) for k in path_patterns}

    # Only species with a single reference strain are surveyed.
    _count = species_strain_counts[_species]
    if _count > 1:
        print(f"Species {_species} has {_count} strains.")
        continue

    # Some species are missing from _all_species_depth. Fall back to an empty
    # Series so the reindex below yields zero depth for every sample; the
    # previous scalar `np.nan` fallback raised AttributeError on `.reindex`.
    if _species in _all_species_depth:
        _species_depth = _all_species_depth[_species]
    else:
        _species_depth = pd.Series(dtype=float)

    if not os.path.exists(_path['flag']):
        print(f"{_species} is missing flag file")
        continue
    if not os.path.exists(_path["fit"]):
        print(f"{_species} is missing fit file")
        continue
    if not os.path.exists(_path['reference_strain_accuracy']):
        print(f"{_species} is missing accuracy file")
        continue
    _species_gene_list = pd.read_table(_path["species_gene_reference"], names=['gene_id']).gene_id.tolist()
    _fit = sf.World.load(_path["fit"])
    _thresh = pd.read_table(_path["strain_thresholds"], index_col='strain')
    _accuracy = pd.read_table(_path['reference_strain_accuracy'], index_col='strain')
    _accuracy_depth_only = pd.read_table(_path['reference_strain_accuracy_depth_only'], index_col='strain')
    # Top strain: highest mean community value over the samples whose name
    # starts with 'xjin_'.
    _top_strain = _fit.community.sel(sample=idxwhere(_fit.community.sample.to_series().str.startswith('xjin_'))).mean("sample").to_series().idxmax()
    if _top_strain not in _accuracy.index:
        print(f"{_species} {_strain} is missing accuracy info")
        continue
    with open(_path["species_free_samples"]) as f:
        _species_free_samples = [line.strip() for line in f]
    _sample_to_strain = pd.read_table(_path["strain_samples"], index_col=['sample']).strain
    _sample_list = idxwhere(_sample_to_strain == _top_strain)
    # Depth of this species in the top strain's samples (0 where unobserved),
    # computed once and reused for all three summaries below.
    _sample_depth = _species_depth.reindex(_sample_list, fill_value=0)

    _accuracy = _accuracy.join(_accuracy_depth_only, rsuffix='_depth_only').loc[_top_strain]
    _accuracy['geno_entropy'] = _fit.genotype.entropy().to_series()[_top_strain]
    _accuracy['num_strain_samples'] = len(_sample_list)
    _accuracy['num_species_free_samples'] = len(_species_free_samples)
    _accuracy['num_hmp_samples'] = sum([not s.startswith('xjin_') for s in _sample_list])
    _accuracy['num_species_genes'] = len(_species_gene_list)
    _accuracy['depth_stdev'] = _sample_depth.std()
    _accuracy['depth_max'] = _sample_depth.max()
    _accuracy['depth_sum'] = _sample_depth.sum()
    try:
        _accuracy['corr_thresh'] = _thresh.loc[_top_strain].correlation
        _accuracy['depth_thresh'] = _thresh.loc[_top_strain].depth_low
    except KeyError as err:  # FIXME: This should never happen
        _accuracy['corr_thresh'] = np.nan
        _accuracy['depth_thresh'] = np.nan
        print(err)
    _accuracy_by_species[_species] = _accuracy

# One row per surveyed species; `power_index` is depth_stdev * sqrt(number of
# strain samples), with missing values set to 0.
all_species_strain_accuracy = (
    pd.DataFrame(_accuracy_by_species).T
    .assign(
        power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_strain_samples)).fillna(0)
    )
)




In [None]:
# Display the formatted 'flag' path from the notebook-level `path` dict.
path['flag']

In [None]:
# Compare accuracy with depth-only calls (x) against depth+correlation calls
# (y). Columns are the three scores; each row colors the same points by a
# different covariate.
d = all_species_strain_accuracy
fig, axs = plt.subplots(4, 3, figsize=(14, 20), sharex=True, sharey=True)

# Axis range and extent of the identity line. (A dead assignment of
# [5e-2, 1 - 1e-3], immediately overwritten, was removed; that range only
# suits logit-scaled axes.)
span = [0, 1]

for row, c in zip(axs, ['depth_sum', 'geno_entropy', 'power_index', 'corr_thresh']):
    row[0].set_ylabel(f'depth+corr ({c})')
    for ax, _score in zip(row, ['precision', 'recall', 'f1']):
        ax.scatter(f'{_score}_depth_only', f'{_score}', s=20, c=c, norm=mpl.colors.PowerNorm(1/3), data=d)
        # Identity line: points above it improved when correlation was added.
        ax.plot(span, span)
        ax.set_aspect('equal')
        ax.set_title(_score)

# `ax` is the last panel; limits propagate through the shared axes.
ax.set_xlim(*span)
ax.set_ylim(*span)
axs[-1, 0].set_xlabel('depth only')

fig.tight_layout()

In [None]:
# Genotype entropy against summed depth (log x-axis), colored by F1.
plt.scatter('depth_sum', 'geno_entropy', c='f1', data=all_species_strain_accuracy)
plt.xscale('log')

In [None]:
# F1 against summed depth (log x-axis), colored by genotype entropy.
plt.scatter('depth_sum', 'f1', c='geno_entropy', data=all_species_strain_accuracy)
plt.xscale('log')
# Explicit decade ticks from 1e-3 to 1e5.
plt.xticks([0.001, 0.01, 0.1, 1, 10, 100, 1_000, 10_000, 100_000])
plt.colorbar()

In [None]:
# Spearman rank correlation of each covariate with F1, in the listed order.
tuple(
    sp.stats.spearmanr(all_species_strain_accuracy[_covariate], all_species_strain_accuracy['f1'])
    for _covariate in [
        'depth_max',
        'depth_sum',
        'depth_stdev',
        'power_index',
        'num_strain_samples',
    ]
)

In [None]:
# F1 against maximum depth (log x-axis), colored by genotype entropy.
plt.scatter('depth_max', 'f1', c='geno_entropy', data=all_species_strain_accuracy)
plt.xscale('log')
# Explicit decade ticks from 1e-2 to 1e4.
plt.xticks([0.01, 0.1, 1, 10, 100, 1_000, 10_000])
plt.colorbar()

In [None]:
# The ten species rows with the highest F1.
all_species_strain_accuracy.sort_values('f1', ascending=False).head(10)

In [None]:
# F1 against maximum depth (log x-axis), colored by F1 on a 0.5-1 scale.
# `c='f1'` was missing: without color data, `vmin`/`vmax` are ignored and the
# colorbar does not describe the points. 'f1' matches the color column used by
# the sibling precision/recall scatter with the same vmin/vmax.
plt.scatter('depth_max', 'f1', data=all_species_strain_accuracy, c='f1', lw=1, edgecolor='k', s=40, vmin=0.5, vmax=1)
plt.ylim(0.4, 1.02)
plt.colorbar()
plt.xscale('log')

In [None]:
# Depth-only recall against depth-only precision, colored by the 'f1' column
# on a 0.5-1 color scale.
plt.scatter('precision_depth_only', 'recall_depth_only', data=all_species_strain_accuracy, c='f1', lw=1, edgecolor='k', s=40, vmin=0.5, vmax=1)
plt.ylim(0, 1.02)
plt.xlim(0, 1.02)
plt.colorbar()
# plt.xscale('log')

In [None]:
# Overlay precision/recall of the depth+correlation calls and the depth-only
# calls. Labels and a legend are added because the two point clouds were
# otherwise indistinguishable.
plt.scatter('precision', 'recall', data=all_species_strain_accuracy, label='depth+corr')
plt.scatter('precision_depth_only', 'recall_depth_only', data=all_species_strain_accuracy, label='depth only')
plt.xlabel('precision')
plt.ylabel('recall')
plt.legend()

In [None]:
# Distribution of F1 across species on 40 evenly spaced bin edges over [0, 1].
plt.hist(all_species_strain_accuracy.f1, bins=np.linspace(0, 1, num=40))
# Bare None keeps the cell from echoing the histogram's return value.
None

In [None]:
# Load the per-species benchmarking tables that exist on disk and stack them
# into one frame with a `to_drop` flag for rows failing basic filters.
ref_strains = pd.read_table('meta/genome.tsv', index_col='genome_id')[lambda x: ~x.genome_path.isna()]
species_strain_counts = ref_strains.value_counts('species_id')
_all_species_depth = species_depth


# Tables are gathered under their own name so `xjin_benchmarking` is only ever
# bound to the final DataFrame (re-running a later cell never sees a list).
_benchmarking_tables = []
for _species in tqdm(species_strain_counts.index):
    # Format paths to reflect _species.
    _path_params = path_params.copy()
    _path_params.update(species=_species)
    _path = {k: path_patterns[k].format(**_path_params) for k in path_patterns}

    # Report (by path) and skip species without a benchmarking table.
    if not os.path.exists(_path['xjin_benchmarking']):
        print(_path["xjin_benchmarking"])
        continue

    _benchmarking_tables.append(pd.read_table(_path["xjin_benchmarking"]).assign(species=_species))

# `pd.concat` on an empty list fails with an opaque "No objects to concatenate";
# raise the same exception type with an actionable message instead.
if not _benchmarking_tables:
    raise ValueError("No xjin_benchmarking tables were found for any species.")

xjin_benchmarking = pd.concat(_benchmarking_tables).assign(
    # Reasonable filters:
    to_drop=lambda x: ( False
        | (x.num_reference_genomes > 1)
        | (x.num_strain_samples != 10)  # FIXME
    )
)

In [None]:
# Benchmarking rows that are not flagged `to_drop`.
xjin_benchmarking[~xjin_benchmarking.to_drop]