In [None]:
%load_ext autoreload

In [None]:
# Run the notebook from the project root so that the relative paths used in later cells resolve.
# Raises KeyError if the PROJECT_ROOT environment variable is not set.
import os as _os
_os.chdir(_os.environ['PROJECT_ROOT'])
# Display the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os

In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

In [None]:
# Notebook-wide plot styling: seaborn "talk" context and a figure DPI of 50.
sns.set_context('talk')
plt.rc('figure', dpi=50)

In [None]:
# Dataset identifiers interpolated into the input paths below.
group_subset = 'xjin'
group = 'xjin_hmp2'
stemA = 'r.proc'

# Registry of input files, keyed by a short name; later cells add more entries.
path = {
    'species_taxonomy': "ref/gtpro/species_taxonomy_ext.tsv",
    'all_species_depth_subset': f"data/group/{group_subset}/{stemA}.gtpro.species_depth.tsv",
    'all_species_depth': f"data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    'midasdb_genomes': "ref/uhgg_genomes_all_4644.tsv",
    'strain_genomes': "meta/genome.tsv",
}

# Fail early, listing every registered file that is not on disk.
path_exists = {filename: os.path.exists(filename) for filename in path.values()}
missing_files = [filename for filename, found in path_exists.items() if not found]

assert not missing_files, '\n'.join(["Missing files:"] + missing_files)

In [None]:
# Load the species depth tables for the full group and for the subset group, then divide every
# row by its own sum so that each row becomes a relative-abundance profile.
# NOTE(review): rows are presumably samples and columns species IDs — confirm against load_species_depth.
species_depth = lib.thisproject.data.load_species_depth(path['all_species_depth'])
species_depth_subset = lib.thisproject.data.load_species_depth(path['all_species_depth_subset'])
rabund = species_depth.apply(lambda x: x / x.sum(), axis=1)
rabund_subset = species_depth_subset.apply(lambda x: x / x.sum(), axis=1)


# Rank columns by how many subset rows exceed a relative abundance of 1e-5 and keep the top n_species.
n_species = 40
top_species = (rabund_subset > 1e-5).sum().sort_values(ascending=False).head(n_species).index

# One thin axes per species, stacked vertically and sharing both axes.
fig, axs = plt.subplots(n_species, figsize=(10, 0.5 * n_species), sharex=True, sharey=True)

# Log-spaced bins spanning relative abundances from 1e-7 to 1.
bins = np.logspace(-7, 0, num=51)

for species_id, ax in zip(top_species, axs):
    # Overlay the subset distribution (drawn first) and the full-group distribution.
    ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    # Fraction of subset rows in which this species exceeds 1e-5 relative abundance.
    prevalence = (rabund_subset[species_id] > 1e-5).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    # Strip axes, background and spines so that the overlapping panels read as one ridge plot.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-7)
    ax.set_ylim(top=300)
    
# `ax` is still bound to the last axes visited by the loop: re-enable its x-axis and bottom spine
# so the shared abundance scale is drawn once.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)

# Negative spacing makes consecutive panels overlap.
fig.subplots_adjust(hspace=-0.75)

In [None]:
# Focal species for the rest of the notebook; the ID is kept as a string because it is used both
# as a table label and inside file paths.
species = '101380'

# Load the species taxonomy table and display the row for the focal species.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(path["species_taxonomy"])
species_taxonomy.loc[species]

In [None]:
# Genome metadata table, read with every column as a string so that IDs compare equal to `species`.
strain_genome = pd.read_table(path["strain_genomes"], dtype='str')
# Display the genome rows recorded for the focal species.
strain_genome[strain_genome.species_id == species]

In [None]:
# Genome IDs recorded for the focal species in the genome metadata table.
strain_genome_ids = strain_genome[strain_genome.species_id == species].genome_id
print(strain_genome_ids)
# Validate *before* indexing: the rest of the notebook assumes exactly one genome, and an empty
# match should fail with this message rather than a bare IndexError.
assert strain_genome_ids.shape[0] == 1, (
    f"Expected exactly one genome for species {species}, found {strain_genome_ids.shape[0]}."
)
strain_genome_id = strain_genome_ids.iloc[0]

In [None]:
# Register the BLASTP tables for the selected strain genome (strain vs. UHGG pangenome, strain vs. itself).
path.update({
    # 'uhgg_x_strain': f'data/species/sp-{species}/genome/midas_uhgg_pangenome.{strain_genome_id}-blastp.tsv',
    'strain_x_uhgg': f'data/species/sp-{species}/genome/{strain_genome_id}.midas_uhgg_pangenome-blastp.tsv',
    'strain_x_strain': f'data/species/sp-{species}/genome/{strain_genome_id}.{strain_genome_id}-blastp.tsv',
})

# Re-check every file registered so far and fail with the list of those missing.
path_exists = {filename: os.path.exists(filename) for filename in path.values()}
missing_files = [filename for filename, found in path_exists.items() if not found]

assert not missing_files, '\n'.join(["Missing files:"] + missing_files)

In [None]:
# Default file path forming for interactive use.
# The stems below are filename fragments; they are only interpolated into the paths that follow.

stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts43-s85-seed0'
stemC = 'sfacts42-seed0'
spgc_params = 'e100'
centroid = 75

# Register the per-species inputs.
# NOTE(review): `species_correlation` and `species_gene_denovo` use ".spgc." without the
# "-{spgc_params}" suffix that `species_gene_mean_depth` carries — confirm this is intended.
# NOTE(review): `cluster_info` lives under "ref/midasdb_uhgg/pangenomes/" whereas
# `reference_copy_number` lives under "ref/midasdb_uhgg_pangenomes/" — confirm both directories exist.
path.update(dict(
    flag=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.gene{centroid}.spgc-{spgc_params}.strain_files.flag",
    fit=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.world.nc",
    refit=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.world.nc",
    strain_correlation=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{centroid}.spgc-{spgc_params}.strain_correlation.tsv",
    strain_depth_ratio=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{centroid}.spgc-{spgc_params}.strain_depth_ratio.tsv",
    strain_fraction=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.comm.tsv",
    species_gene_mean_depth=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.gene{centroid}.spgc-{spgc_params}.species_depth.tsv",
    species_gtpro_depth=f"data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    species_correlation=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.gene{centroid}.spgc.species_correlation.tsv",
    species_gene_denovo=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.gene{centroid}.spgc.species_gene.list",
    species_gene_reference=f"data/species/sp-{species}/midasuhgg.pangenome.gene{centroid}.species_gene-trim25-prev95.list",
    strain_thresholds=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{centroid}.spgc-{spgc_params}.strain_gene_threshold.tsv",
    gene_annotations=f"ref/midasdb_uhgg_gene_annotations/sp-{species}.gene{centroid}_annotations.tsv",
    raw_gene_depth=f"data/group/{group}/species/sp-{species}/{stemA}.gene{centroid}.depth.nc",
    reference_copy_number=f"ref/midasdb_uhgg_pangenomes/{species}/gene{centroid}.reference_copy_number.nc",
    cluster_info=f"ref/midasdb_uhgg/pangenomes/{species}/cluster_info.txt",
    gtpro_reference_genotype=f"data/species/sp-{species}/gtpro_ref.mgtp.nc",
))

# Re-check every file registered so far (including those from earlier cells).
path_exists = {}
for p in path:
    path_exists[path[p]] = os.path.exists(path[p])

# On failure the message lists each missing file on its own line.
assert all(path_exists.values()), '\n'.join(["Missing files:"] + [p for p in path_exists if not path_exists[p]])

In [None]:
# Display the flag-file path built from the stems above.
path['flag']

In [None]:
# Load the fitted sfacts world and drop low-abundance strains using a 0.05 cutoff.
# NOTE(review): the exact meaning of the 0.05 cutoff is defined by sfacts — confirm.
fit = sf.World.load(path['fit']).drop_low_abundance_strains(0.05)
print(fit.sizes)
# Random subset of 500 positions, used to keep the plot below small.
# NOTE(review): no random seed is set in the visible cells, so this subset presumably changes
# between runs — confirm how sfacts draws the sample and seed it if reproducibility matters.
position_ss = fit.random_sample(position=500).position

In [None]:
# Diagnostic plots of the fit on the subsampled positions. Only the community plot is active;
# the other sfacts plots are kept commented out as alternatives for interactive use.
# sf.plot.plot_metagenotype(
#     fit.sel(position=position_ss),
#     # scaley=0.2, scalex=0.3,
#     row_linkage_func=lambda w: w.metagenotype.linkage("position"),
#     col_linkage_func=lambda w: w.community.linkage(),
# )
# sf.plot.plot_depth(
#     fit.sel(position=position_ss),
#     # scaley=0.2, scalex=0.3,
#     row_linkage_func=lambda w: w.metagenotype.linkage("position"),
#     col_linkage_func=lambda w: w.community.linkage(),
# )
# sf.plot.plot_dominance(
#     fit.sel(position=position_ss),
#     # scaley=0.2, scalex=0.3,
#     row_linkage_func=lambda w: w.metagenotype.linkage("position"),
#     col_linkage_func=lambda w: w.community.linkage(),
# )
# Community plot with columns ordered by the community linkage and rows by the strain genotype linkage.
sf.plot.plot_community(
    fit.sel(position=position_ss),
    # scaley=0.2, scalex=0.3,
    col_linkage_func=lambda w: w.community.linkage(),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)
# sf.plot.plot_genotype(
#     fit.sel(position=position_ss),
#     # scaley=0.2, scalex=0.3,
#     col_linkage_func=lambda w: w.metagenotype.linkage("position"),
#     row_linkage_func=lambda w: w.genotype.linkage("strain"),
# )

In [None]:
# Mean community value of each strain across samples, computed once and reused below.
_mean_strain_fraction = fit.community.mean("sample")
_mean_strain_series = _mean_strain_fraction.to_series()

# Show the five strains with the highest mean, then keep the top one as the focal strain.
print(_mean_strain_series.sort_values(ascending=False).head(5))
top_inferred_strain = _mean_strain_series.idxmax()

# The downstream evaluation only makes sense if the focal strain has a mean above 0.1.
assert _mean_strain_fraction.sel(strain=top_inferred_strain) > 0.1

In [None]:
# Cluster table indexed by the `centroid_99` column (kept as a column too), with the index
# renamed to `gene_id`.
gene_cluster = pd.read_table(
    path["cluster_info"]
).set_index('centroid_99', drop=False).rename_axis(index='gene_id')
# Annotation table read with explicit column names, indexed by locus tag; column names are
# lower-cased afterwards (e.g. `EC_number` -> `ec_number`).
# NOTE(review): `names=` assumes the file has no header row — confirm.
gene_annotation = pd.read_table(
    path["gene_annotations"],
    names=['locus_tag', 'ftype', 'length_bp', 'gene', 'EC_number', 'COG', 'product'],
    index_col='locus_tag',
).rename(columns=str.lower)

# One row per distinct centroid at the chosen clustering level, joined with its annotation.
gene_meta = gene_cluster.loc[gene_cluster[f'centroid_{centroid}'].unique()].join(gene_annotation)

In [None]:
# Column names applied to the header-less BLASTP tables read below.
# NOTE(review): this is the standard 12-column BLAST tabular layout; presumably the
# *-blastp.tsv files were written in that format — confirm against the pipeline.
blastp_header_names = (
    'qseqid sseqid pident length mismatch gapopen '
    'qstart qend sstart send evalue bitscore'
).split()

In [None]:
# BLASTP of the strain genome against itself (per the file name); used only for normalization.
_strain_x_strain = (
    pd.read_table(
        path['strain_x_strain'],
        names=blastp_header_names
    )
)

# Highest bitscore observed for each query in the self comparison.
_max_bitscore = _strain_x_strain.groupby(['qseqid']).bitscore.max()

# BLASTP of the strain genome against the UHGG pangenome, annotated with:
#  - bitscore_ratio: bitscore divided by the query's maximum self-comparison bitscore
#    (NaN for queries absent from the self comparison);
#  - sseq_centroid: the subject gene mapped to its centroid at the chosen clustering level.
strain_x_uhgg = (
    pd.read_table(
        path['strain_x_uhgg'],
        names=blastp_header_names
    )
    .assign(bitscore_ratio=lambda x: x.bitscore / x.qseqid.map(_max_bitscore))
    .assign(sseq_centroid=lambda x: x.sseqid.map(gene_cluster[f'centroid_{centroid}']))
)

# For each query keep the row with the highest bitscore, then for each centroid take the largest
# bitscore ratio among the queries whose best hit is that centroid.
# NOTE(review): ties in bitscore are resolved by whatever order sort_values leaves them in.
best_uhgg_hit = strain_x_uhgg.groupby('qseqid').apply(lambda d: d.sort_values('bitscore').iloc[-1]).groupby('sseq_centroid').bitscore_ratio.max()

In [None]:
# Best bitscore ratio for every (query ORF, MIDAS centroid) pair.
orf_x_midas = strain_x_uhgg.groupby(['qseqid', 'sseq_centroid']).bitscore_ratio.max()


# Wide ORF-by-centroid matrix; pairs without a BLAST row count as 0.
_orf_by_midas = orf_x_midas.unstack(fill_value=0)

# Distribution of the best ratio per centroid (first) and per ORF (second, translucent).
bins = np.linspace(0, 1)
plt.hist(_orf_by_midas.max(axis=0), bins=bins, density=True)
plt.hist(_orf_by_midas.max(axis=1), bins=bins, density=True, alpha=0.5)
plt.yscale('log')
None

In [None]:
# For each ORF, count the centroids it hits with a bitscore ratio above 0.5, then tabulate how
# many ORFs have 0, 1, 2, ... such hits.
(orf_x_midas.unstack().astype(float) > 0.5).sum(1).value_counts().sort_index()

In [None]:
# Percent identity against bitscore ratio for every BLASTP row.
plt.scatter('pident', 'bitscore_ratio', data=strain_x_uhgg, s=1)
# Reference diagonal from (0, 0) to (100, 1).
plt.plot([0, 100], [0, 1], color='k', linestyle='--')

In [None]:
# Long (gene_id, strain) tables pivoted to gene-by-strain matrices.
# Missing (gene, strain) pairs become 0 for the correlation but stay NaN for the depth ratio.
strain_corr = pd.read_table(path["strain_correlation"], index_col=['gene_id', 'strain']).squeeze().unstack('strain', fill_value=0)
strain_depth = pd.read_table(
    path["strain_depth_ratio"],
    index_col=['gene_id', 'strain']
).squeeze().unstack()
# strain_corr, strain_depth = align_indexes(*align_indexes(strain_corr, strain_depth), axis="columns")

In [None]:
# Per-strain gene-calling thresholds, with columns renamed to the names used in this notebook.
strain_thresholds = (
    pd.read_table(path["strain_thresholds"], index_col='strain')
    .rename(columns=dict(
        correlation_strict='corr_threshold_strict',
        correlation_moderate='corr_threshold_moderate',
        correlation_lenient='corr_threshold_lenient',
        depth_high='depth_thresh_high',
        depth_low='depth_thresh_low',
    ))
)

# Strain metadata: the thresholds plus each strain's genotype entropy from the fit.
# The commented-out joins depend on objects (refit, sample_to_strain, strain_to_sample_list)
# that are not defined in the visible cells.
_strain_meta = (
    strain_thresholds
    .join(fit.genotype.entropy().to_series().rename('genotype_entropy'))
    # .join(refit.genotype.entropy().to_series().rename('genotype_refit_entropy'))
    # .join(fit.metagenotype.entropy().to_series().rename('metagenotype_entropy').groupby(sample_to_strain).mean().rename(int))
    # .join(strain_to_sample_list.apply(len).rename('num_samples'))
    # .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).std().rename('depth_stdev').rename(int))
    # .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).max().rename('depth_max').rename(int))
    # .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).sum().rename('depth_sum').rename(int))
    # .assign(power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_samples)).fillna(0))
)
strain_meta = _strain_meta


# power_index_thresh = 5
# genotype_entropy_thresh = 0.2
# genotype_refit_entropy_thresh = 1.0

# high_power_strain_list = idxwhere(
#     (strain_meta.power_index > power_index_thresh)
#     & (strain_meta.genotype_entropy < genotype_entropy_thresh)
#     & (strain_meta.genotype_refit_entropy < genotype_refit_entropy_thresh)
# )
# print(len(high_power_strain_list))
# highest_power_strain_list = strain_meta.sort_values('power_index', ascending=False).head(3).index

# plt.scatter(strain_meta.power_index, strain_meta.corr_threshold_moderate, c=strain_meta.genotype_refit_entropy, alpha=0.5)
# plt.axvline(power_index_thresh, lw=1, linestyle='--', color='k')
# plt.colorbar()
# plt.xscale('log')

# Display the assembled table.
strain_meta

In [None]:
# Two-column, header-less table read as a Series of correlations.
# NOTE(review): the index is named 'sample' here, but the Series is later combined with
# gene-indexed data in `strain_scores`, so the first column presumably holds gene IDs — confirm.
species_corr = pd.read_table(path["species_correlation"], names=['sample', 'correlation'], index_col='sample').squeeze()

In [None]:
# One entry per line of the de novo species-gene list, with surrounding whitespace stripped.
with open(path["species_gene_denovo"]) as f:
    species_gene_denovo_hit = [line.strip() for line in f]

In [None]:
# One entry per line of the reference species-gene list, with surrounding whitespace stripped.
with open(path["species_gene_reference"]) as f:
    species_gene_reference_hit = [line.strip() for line in f]

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Thresholds selected for the focal strain: the low depth cutoff and the moderate correlation cutoff.
depth_threshold = strain_meta.depth_thresh_low.loc[top_inferred_strain]
corr_threshold = strain_meta.corr_threshold_moderate.loc[top_inferred_strain]
# corr_threshold = 0.95  # Set manually, but this could/should be the automatically selected threshold.
# depth_threshold = 0.2  # Set manually, but this could/should be the automatically selected threshold.
bitscore_threshold = 0.5

# Per-gene score table for the focal strain. The inputs are aligned on their indexes and any gene
# missing from an input gets 0 there; rows end up sorted by ascending bitscore ratio.
strain_scores = (
    pd.DataFrame(dict(
        bitscore_ratio=best_uhgg_hit,
        strain_corr=strain_corr[top_inferred_strain],
        strain_depth=strain_depth[top_inferred_strain],
        species_corr=species_corr,
    ))
    .fillna(0)
    .assign(
        bitscore_hit=lambda x: x.bitscore_ratio > bitscore_threshold,
        corr_and_depth_hit=lambda x: (x.strain_corr > corr_threshold) & (x.strain_depth > depth_threshold),
        species_gene_denovo=lambda x: x.index.to_series().isin(species_gene_denovo_hit),
        species_gene_reference=lambda x: x.index.to_series().isin(species_gene_reference_hit),
        corr_complement=lambda x: 1 - x.strain_corr,
    )
    .sort_values('bitscore_ratio')
)
# NOTE(review): this bare expression is not the last statement of the cell, so it displays nothing.
strain_scores

fig, axs = plt.subplots(3, figsize=(5, 10), sharex=True, sharey=True)


# One panel per boolean column: genes where the flag is False are drawn first (faint), genes where
# it is True are drawn on top (stronger).
for ax, c in zip(axs.flatten(), ['bitscore_hit', 'species_gene_denovo', 'species_gene_reference']):
    ax.scatter(
        'corr_complement',
        'strain_depth',
        data=strain_scores[lambda x: ~x[c]],
        s=5,
        # c=c,
        alpha=0.1   
    )
    ax.scatter(
        'corr_complement',
        'strain_depth',
        data=strain_scores[lambda x: x[c]],
        s=5,
        # c=c,
        alpha=0.5   
    )
    # Dashed guides: the depth threshold, a depth ratio of 1 (left half of the panel only), and the
    # correlation threshold expressed as 1 - correlation.
    ax.axhline(depth_threshold, lw=1, linestyle='--')
    ax.axhline(1, xmin=0., xmax=0.5, lw=1, linestyle='--', color='k')
    ax.axvline(1 - corr_threshold, lw=1, linestyle='--')
    ax.set_xscale('log')
    ax.set_yscale('symlog', linthresh=1e-1)
    ax.set_ylim(bottom=0)
    # Inverted x so that higher correlation (smaller complement) is on the right.
    ax.invert_xaxis()

In [None]:
from itertools import product

# Evaluate the depth/correlation gene calls ("SPGC hits") for the focal strain against BLASTP hits
# of the strain genome's ORFs, over a grid of (depth, correlation) thresholds.

# Restrict to MIDAS genes with a positive depth ratio in the focal strain.
midas_gene_list = idxwhere(strain_scores.strain_depth > 0)

_correlation = strain_scores.strain_corr[midas_gene_list]
_depth = strain_scores.strain_depth[midas_gene_list]
# ORF (rows) x MIDAS gene (columns) matrix of bitscore ratios. Pairs without a BLAST row stay NaN
# and genes never seen by BLAST are filled with 0, so neither counts as a hit below.
_bitscore = orf_x_midas.unstack().reindex(columns=midas_gene_list, fill_value=0).astype(float)

b_thresh = 0.5
_bitscore_hit = _bitscore > b_thresh
# Only consider
#  - MIDAS genes that have <= 1 hit
#  - ORFs that (1) don't hit any MIDAS genes with > 1 hit
#  - and (2) hit exactly 1 MIDAS gene
# `axis` is passed by keyword throughout: pandas >= 2.0 rejects a positional axis for `.any()`.
_midas_multi_hit = _bitscore_hit.sum(axis=0) > 1
_midas_1to1 = ~_midas_multi_hit
_orf_1to1 = (
    (_bitscore_hit.loc[:, _midas_1to1].sum(axis=1) == 1)
    & ~_bitscore_hit.loc[:, _midas_multi_hit].any(axis=1)
)
_midas_1to1 = set(idxwhere(_midas_1to1))
_orf_1to1 = set(idxwhere(_orf_1to1))

# Correlation thresholds are spaced logarithmically in (1 - correlation), from 0 up to 1.
c_thresh_list_complement = np.array(list(reversed([0] + list(np.logspace(-3, 0, num=21)))))
c_thresh_list = 1 - c_thresh_list_complement
d_thresh_list = np.linspace(0, 1.0, num=11)


def _safe_ratio(numerator, denominator, default):
    """Return numerator / denominator, or `default` when the denominator is zero."""
    return numerator / denominator if denominator != 0 else default


def _as_threshold_grid(values):
    """Wrap a (depth x correlation) array in a DataFrame labelled by the threshold grids."""
    return (
        pd.DataFrame(values, index=d_thresh_list, columns=c_thresh_list)
        .rename_axis(index='depth_threshold', columns='correlation_threshold')
    )


precision_result = np.empty((len(d_thresh_list), len(c_thresh_list)))
recall_result = np.empty_like(precision_result)
precision_result_1to1 = np.empty_like(precision_result)
recall_result_1to1 = np.empty_like(precision_result)

for (i, d_thresh), (j, c_thresh) in tqdm(
    product(enumerate(d_thresh_list), enumerate(c_thresh_list)),
    total=len(d_thresh_list) * len(c_thresh_list)
):
    _spgc_hit = (_correlation >= c_thresh) & (_depth >= d_thresh)
    # BLAST hits restricted to the genes called at these thresholds (computed once per grid point).
    _hit_among_called = _bitscore_hit.loc[:, _spgc_hit]
    _called_gene_has_orf = _hit_among_called.any(axis=0)
    _orf_has_called_gene = _hit_among_called.any(axis=1)

    tp_midas = set(idxwhere(_called_gene_has_orf))  # Called MIDAS genes that some ORF hits by BLAST.
    fp_midas = set(idxwhere(~_called_gene_has_orf))  # Called MIDAS genes that no ORF hits by BLAST.
    fn_orf = set(idxwhere(~_orf_has_called_gene))  # ORFs with no BLAST hit to any called gene.
    tp_orf = set(idxwhere(_orf_has_called_gene))  # ORFs with a BLAST hit to at least one called gene.

    n_tp_midas = len(tp_midas)
    n_fp_midas = len(fp_midas)
    n_fn_orf = len(fn_orf)
    n_tp_orf = len(tp_orf)

    n_tp_1to1 = len(tp_midas & _midas_1to1)
    n_fp_1to1 = len(fp_midas & _midas_1to1)
    n_fn_1to1 = len(fn_orf & _orf_1to1)
    # Use the gene IDs (the index): iterating a Series yields its boolean values, which can never
    # intersect the gene-ID set. This count is not used by any visible cell.
    n_tn_1to1 = len(set(_spgc_hit.index) & _midas_1to1) - (n_tp_1to1 + n_fp_1to1 + n_fn_1to1)

    # With nothing called, precision defaults to 1; with nothing to recover, recall defaults to 0.
    # NOTE(review): the 1-to-1 recall mixes a gene-based TP count with an ORF-based FN count.
    precision_result[i, j] = _safe_ratio(n_tp_midas, n_tp_midas + n_fp_midas, 1)
    recall_result[i, j] = _safe_ratio(n_tp_orf, n_tp_orf + n_fn_orf, 0)
    precision_result_1to1[i, j] = _safe_ratio(n_tp_1to1, n_tp_1to1 + n_fp_1to1, 1)
    recall_result_1to1[i, j] = _safe_ratio(n_tp_1to1, n_tp_1to1 + n_fn_1to1, 0)

# Label the grids and derive F1 as the harmonic mean of precision and recall.
precision_result = _as_threshold_grid(precision_result)
recall_result = _as_threshold_grid(recall_result)
f1_result = _as_threshold_grid(sp.stats.hmean(np.stack([precision_result, recall_result])))

precision_result_1to1 = _as_threshold_grid(precision_result_1to1)
recall_result_1to1 = _as_threshold_grid(recall_result_1to1)
f1_result_1to1 = _as_threshold_grid(sp.stats.hmean(np.stack([precision_result_1to1, recall_result_1to1])))


# print(f'{c_thresh}\t{d_thresh}\t{precision_midas:0.2f}\t{recall_orf:0.2f}\t{n_fp_midas}\t{n_fn_orf}\t{n_tp_midas}\t{n_tp_orf}')

# MIDAS genes that map to multiple ORFs, tend to have higher depth ratios.
#sns.regplot(x=_bitscore_hit.loc[tp_orf, tp_midas].sum(axis=0), y=strain_scores.loc[tp_midas].strain_depth)

In [None]:
fig, axs = plt.subplots(3, figsize=(5, 10))

# One heatmap per metric over the (1 - correlation, depth) threshold grid, each with its own
# colorbar and with the focal strain's selected thresholds marked by red dashed lines.
_panels = [
    ('precision', precision_result),
    ('recall', recall_result),
    ('f1', f1_result),
]
for ax, (_title, _grid) in zip(axs, _panels):
    ax.set_title(_title)
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    artist = ax.pcolormesh(c_thresh_list_complement, d_thresh_list, _grid, norm=mpl.colors.PowerNorm(1, vmin=0, vmax=1))
    ax.set_xscale('symlog', linthresh=1e-2, linscale=0.5)
    ax.set_yscale('symlog', linthresh=1e-1, linscale=0.5)
    ax.set_ylim(0)
    ax.invert_xaxis()
    cbar = fig.colorbar(artist, cax=cax)
    ax.axhline(strain_meta.depth_thresh_low.loc[top_inferred_strain], lw=1, linestyle='--', color='r')
    ax.axvline(1 - strain_meta.corr_threshold_moderate.loc[top_inferred_strain], lw=1, linestyle='--', color='r')

fig.tight_layout()

In [None]:
fig, axs = plt.subplots(3, figsize=(5, 10))

# Same three heatmaps as for the full evaluation, but for the 1-to-1 restricted metrics.
_panels = [
    ('precision', precision_result_1to1),
    ('recall', recall_result_1to1),
    ('f1', f1_result_1to1),
]
for ax, (_title, _grid) in zip(axs, _panels):
    ax.set_title(_title)
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    artist = ax.pcolormesh(c_thresh_list_complement, d_thresh_list, _grid, norm=mpl.colors.PowerNorm(1, vmin=0, vmax=1))
    ax.set_xscale('symlog', linthresh=1e-2, linscale=0.5)
    ax.set_yscale('symlog', linthresh=1e-1, linscale=0.5)
    ax.set_ylim(0)
    ax.invert_xaxis()
    cbar = fig.colorbar(artist, cax=cax)
    ax.axhline(strain_meta.depth_thresh_low.loc[top_inferred_strain], lw=1, linestyle='--', color='r')
    ax.axvline(1 - strain_meta.corr_threshold_moderate.loc[top_inferred_strain], lw=1, linestyle='--', color='r')

fig.tight_layout()

In [None]:
fig, axs = plt.subplots(3, figsize=(5, 10))

# Difference between the 1-to-1 restricted metrics and the full metrics, one heatmap per metric.
# No fixed color norm is applied here, so each panel is scaled to its own data.
_panels = [
    ('precision-diff', precision_result_1to1 - precision_result),
    ('recall-diff', recall_result_1to1 - recall_result),
    ('f1-diff', f1_result_1to1 - f1_result),
]
for ax, (_title, _grid) in zip(axs, _panels):
    ax.set_title(_title)
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    artist = ax.pcolormesh(c_thresh_list_complement, d_thresh_list, _grid)
    ax.set_xscale('symlog', linthresh=1e-2, linscale=0.5)
    ax.set_yscale('symlog', linthresh=1e-1, linscale=0.5)
    ax.set_ylim(0)
    ax.invert_xaxis()
    cbar = fig.colorbar(artist, cax=cax)
    ax.axhline(strain_meta.depth_thresh_low.loc[top_inferred_strain], lw=1, linestyle='--', color='r')
    ax.axvline(1 - strain_meta.corr_threshold_moderate.loc[top_inferred_strain], lw=1, linestyle='--', color='r')

fig.tight_layout()

In [None]:
# Precision and recall of the focal strain's gene calls at the thresholds actually selected for it.
# The resulting tp_midas / fp_midas / fn_orf / tp_orf sets are reused by later cells.
_correlation = strain_scores.strain_corr[midas_gene_list]
_depth = strain_scores.strain_depth[midas_gene_list]
# ORF (rows) x MIDAS gene (columns) matrix of bitscore ratios; NaN / 0 entries never count as hits.
_bitscore = orf_x_midas.unstack().reindex(columns=midas_gene_list, fill_value=0).astype(float)

b_thresh = 0.5
_bitscore_hit = _bitscore > b_thresh
# Only consider
#  - MIDAS genes that have <= 1 hit
#  - ORFs that (1) don't hit any MIDAS genes with > 1 hit
#  - and (2) hit exactly 1 MIDAS gene
# `axis` is passed by keyword throughout: pandas >= 2.0 rejects a positional axis for `.any()`.
_midas_multi_hit = _bitscore_hit.sum(axis=0) > 1
_midas_1to1 = ~_midas_multi_hit
_orf_1to1 = (_bitscore_hit.sum(axis=1) == 1) & ~_bitscore_hit.loc[:, _midas_multi_hit].any(axis=1)
_midas_1to1 = set(idxwhere(_midas_1to1))
_orf_1to1 = set(idxwhere(_orf_1to1))

d_thresh = strain_meta.depth_thresh_low.loc[top_inferred_strain]
c_thresh = strain_meta.corr_threshold_moderate.loc[top_inferred_strain]

c_thresh_list_complement = np.array(list(reversed([0] + list(np.logspace(-3, 0, num=21)))))
c_thresh_list = 1 - c_thresh_list_complement
d_thresh_list = np.linspace(0, 1.0, num=11)

_spgc_hit = (_correlation >= c_thresh) & (_depth >= d_thresh)
# BLAST hits restricted to the genes called at the selected thresholds.
_hit_among_called = _bitscore_hit.loc[:, _spgc_hit]
_called_gene_has_orf = _hit_among_called.any(axis=0)
_orf_has_called_gene = _hit_among_called.any(axis=1)

tp_midas = set(idxwhere(_called_gene_has_orf))  # Called MIDAS genes that some ORF hits by BLAST.
fp_midas = set(idxwhere(~_called_gene_has_orf))  # Called MIDAS genes that no ORF hits by BLAST.
fn_orf = set(idxwhere(~_orf_has_called_gene))  # ORFs with no BLAST hit to any called gene.
tp_orf = set(idxwhere(_orf_has_called_gene))  # ORFs with a BLAST hit to at least one called gene.

n_tp_midas = len(tp_midas)
n_fp_midas = len(fp_midas)
n_fn_orf = len(fn_orf)
n_tp_orf = len(tp_orf)

n_tp_1to1 = len(tp_midas & _midas_1to1)
n_fp_1to1 = len(fp_midas & _midas_1to1)
n_fn_1to1 = len(fn_orf & _orf_1to1)
# Use the gene IDs (the index): iterating a Series yields its boolean values, which can never
# intersect the gene-ID set. This count is not used by any visible cell.
n_tn_1to1 = len(set(_spgc_hit.index) & _midas_1to1) - (n_tp_1to1 + n_fp_1to1 + n_fn_1to1)

# Same zero-denominator conventions as the threshold-grid cell (precision -> 1, recall -> 0), so
# that an empty call set reports a value instead of raising ZeroDivisionError.
_n_called = n_tp_midas + n_fp_midas
_n_orf = n_tp_orf + n_fn_orf
_n_called_1to1 = n_tp_1to1 + n_fp_1to1
_n_target_1to1 = n_tp_1to1 + n_fn_1to1

precision = n_tp_midas / _n_called if _n_called != 0 else 1
recall = n_tp_orf / _n_orf if _n_orf != 0 else 0
precision_1to1 = n_tp_1to1 / _n_called_1to1 if _n_called_1to1 != 0 else 1
recall_1to1 = n_tp_1to1 / _n_target_1to1 if _n_target_1to1 != 0 else 0

# Output format: precision (1-to-1 precision) recall (1-to-1 recall).
print(f"{precision:.4f} ({precision_1to1:.4f}) {recall:.4f} ({recall_1to1:.4f})")

In [None]:
fig, ax = plt.subplots()

# F1 contours over the (1 - correlation, depth) threshold grid, at levels 0, 0.1, ..., 1.
ax.contour(c_thresh_list_complement, d_thresh_list, f1_result, levels=np.linspace(0, 1, num=11))

# Per-gene scores of the focal strain drawn over the contours.
# The points are not color-mapped (`c` is disabled), so the `norm`/`cmap` arguments that
# matplotlib would ignore, and the colorbar that would describe no data, are omitted.
artist = ax.scatter(
    'corr_complement',
    'strain_depth',
    data=strain_scores.assign(corr_complement=lambda x: 1 - x.strain_corr),
    s=5,
    # c='bitscore_ratio',
    alpha=0.1,
)


ax.set_xscale('log')
ax.set_yscale('symlog', linthresh=1e-1, linscale=0.5)
ax.set_ylim(0)
# Inverted x so that higher correlation (smaller complement) is on the right.
ax.invert_xaxis()
# Red dashed lines mark the thresholds selected for the focal strain.
ax.axhline(strain_meta.depth_thresh_low.loc[top_inferred_strain], lw=1, linestyle='--', color='r')
ax.axvline(1 - strain_meta.corr_threshold_moderate.loc[top_inferred_strain], lw=1, linestyle='--', color='r')

In [None]:
# Genes with a BLAST hit and a positive depth ratio, joined with their centroid length, plus
# log10-transformed length and depth columns for plotting.
# NOTE(review): dividing the length by 3 presumably converts nucleotides to amino acids — confirm.
d = strain_scores[lambda x: (x.bitscore_hit) & (x.strain_depth > 0)].join(gene_cluster.centroid_99_length).assign(
        log_centroid_99_length=lambda x: np.log10(x.centroid_99_length / 3),
        log_strain_depth=lambda x: np.log10(x.strain_depth),
    )

# LOWESS trend of log depth ratio against log length.
sns.regplot(
    x='log_centroid_99_length',
    y='log_strain_depth',
    data=d,
    lowess=True,
    scatter_kws=dict(s=2),
)
# Reference line at log10(depth ratio) = 0, i.e. a depth ratio of 1.
plt.axhline(0, lw=1, linestyle='--', color='k')
plt.ylim(-1.5, 1.5)
# plt.xscale('log')
# plt.yscale('log')

In [None]:
# Log-spaced length bins from 1 to 10^4.
bins = np.logspace(0, 4, num=50)

# ORF-by-MIDAS-gene bitscore ratios with every missing pair set to 0.
d = orf_x_midas.unstack().reindex(columns=midas_gene_list, fill_value=0).astype(float).fillna(0)

# Annotation joined with the per-gene scores, built once and reused for every histogram.
_gene_info = gene_annotation.join(strain_scores)
# For each false-negative ORF, the MIDAS gene with the highest bitscore ratio in its row.
_fn_best_gene = d.loc[list(fn_orf)].idxmax(1)

# plt.hist(gene_annotation.join(strain_scores).loc[d.loc[list(tp_orf)].idxmax(1)].length_bp, bins=bins, alpha=0.5, density=True)
for _label, _genes in [('FP', list(fp_midas)), ('FN', _fn_best_gene), ('TP', list(tp_midas))]:
    plt.hist(_gene_info.loc[_genes].length_bp, bins=bins, alpha=0.5, density=True, label=_label)
plt.xscale('log')
plt.xlabel('MIDAS Gene Centroid Length')
plt.ylabel('Relative Number of Genes')
plt.legend()
None

In [None]:
# Annotation and scores of the false-positive MIDAS genes, sorted by ascending bitscore ratio.
gene_annotation.join(strain_scores).loc[list(fp_midas)].sort_values('bitscore_ratio')