## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Run the notebook from the project root so that the relative paths used in
# later cells ("meta/...", "ref/...", "data/...", "fig/...") resolve.
# Raises KeyError if the PROJECT_ROOT environment variable is not set.
_os.chdir(_os.environ["PROJECT_ROOT"])
# Show the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

## Style

In [None]:
sns.set_context("paper")
plt.rcParams["figure.dpi"] = 300

In [None]:
mpl.scale.register_scale(mpl.scale.FuncScale)

# Load and Check Data

In [None]:
# Genome IDs whose genome_group_id is "xjin".
# NOTE(review): idxwhere is a project helper (lib.pandas_util); presumably it
# returns the index labels at which the boolean Series is True -- confirm.
genome_list = idxwhere(
    pd.read_table("meta/genome_group.tsv", index_col="genome_id").genome_group_id
    == "xjin"
)
len(genome_list)

In [None]:
genome = pd.read_table("meta/genome.tsv", index_col="genome_id").loc[genome_list]
assert genome.index.is_unique

In [None]:
# Unique species among the benchmark genomes, leaving out genomes whose
# species is the "TODO" placeholder or whose genome_path is empty.
species_list = list(
    genome.loc[genome_list]
    .loc[lambda tbl: (tbl.species_id != "TODO") & (tbl.genome_path != "")]
    .species_id.unique()
)
len(species_list)

In [None]:
# Write the benchmark species list to disk, one species ID per line.
with open("fig/xjin_benchmark_species_list.txt", "w") as f:
    for species in species_list:
        f.write(f"{species}\n")

### Taxonomy

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a semicolon-delimited lineage string into a rank-labelled Series.

    The string is split on ";" and the pieces are labelled, in order, with the
    seven rank prefixes "d__" through "s__". Values are kept exactly as they
    appear in the input (nothing is stripped).
    """
    rank_labels = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    return pd.Series(taxonomy_string.split(";"), index=rank_labels)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Build a table indexed by species_id with one column per taxonomic rank.
species_taxonomy = (
    # Keep only rows where the genome is its own species representative.
    pd.read_table(species_taxonomy_inpath)[lambda x: x.Genome == x.Species_rep]
    # species_id is "1" followed by the third "-"-separated field of the
    # MGnify accession.
    .assign(species_id=lambda x: "1" + x.MGnify_accession.str.split("-").str[2])
    .set_index("species_id")
    # Expand each Lineage string into d__/p__/c__/o__/f__/g__/s__ columns.
    .Lineage.apply(parse_taxonomy_string)
)
species_taxonomy

In [None]:
species_taxonomy.loc[species_list]

In [None]:
# Fixed display order of phyla; commented-out entries are kept as a record of
# phyla that were considered but are not part of the palette.
phylum_order = [
    "p__Euryarchaeota",
    "p__Thermoplasmatota",
    "p__Firmicutes",
    "p__Firmicutes_A",
    "p__Firmicutes_C",
    # "p__Firmicutes_B", # None in species_list1
    # "p__Firmicutes_G", # B/G/I not sure how related to C or A
    # "p__Firmicutes_I", #
    # "p__Cyanobacteria", # None in species_list1
    "p__Actinobacteriota",
    "p__Synergistota",
    "p__Fusobacteriota",
    "p__Campylobacterota",
    "p__Proteobacteria",
    "p__Desulfobacterota_A",
    "p__Bacteroidota",
    "p__Verrucomicrobiota",
    # "dummy0", # 18
    # "dummy1", # 19
    # "dummy2", # 20
]

# NOTE(review): construct_ordered_palette is a project helper (lib.plot);
# from its use below it returns a mapping from phylum name to a color --
# the exact effect of `desaturate_levels` should be confirmed there.
phylum_palette = lib.plot.construct_ordered_palette(
    phylum_order,
    cm="rainbow",
    desaturate_levels=[1.0, 0.5],
)

# Print each phylum's color and draw a legend-only figure (empty scatters
# contribute legend entries but no points).
for p__ in phylum_order:
    print(p__, phylum_palette[p__])
    plt.scatter([], [], color=phylum_palette[p__], label=p__)
plt.legend(ncols=1)
lib.plot.hide_axes_and_spines()

# assert len(set(phylum_palette.values())) == len((phylum_palette.values()))

In [None]:
# Species IDs (cast to str) whose species_group_id is "xjin".
species_group = (
    pd.read_table("meta/species_group.tsv")[lambda x: x.species_group_id == "xjin"]
    .species_id.astype(str)
    .to_list()
)
# Preview the first few entries.
species_group[:5]

In [None]:
genome[lambda x: x.species_id == "100003"]

In [None]:
# Legend labels for the integer QC codes. Codes 0-3 are the values returned by
# assign_qc_code; -1 is the fill value given to genomes that end up without a
# matched strain.
qc_code_meaning = {
    -1: "fail",
    0: "passes",
    1: "noise",
    2: "species-gene",
    3: "both",
}


def assign_qc_code(x):
    """Map a strain's two QC flags onto an integer code.

    Reads ``x.passes_species_gene_frac`` and
    ``x.passes_log_selected_gene_depth_ratio_std`` and returns:

    * 0 -- both checks pass
    * 1 -- only the depth-ratio-std (noise) check fails
    * 2 -- only the species-gene-fraction check fails
    * 3 -- both checks fail
    """
    species_gene_ok = bool(x.passes_species_gene_frac)
    noise_ok = bool(x.passes_log_selected_gene_depth_ratio_std)
    if species_gene_ok and noise_ok:
        return 0
    if species_gene_ok:
        return 1
    if noise_ok:
        return 2
    return 3

In [None]:
# data/group/xjin/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.STRAIN_MATCH_BENCHMARK_GRID.flag

# Collect, for every benchmark genome, the genotype-matching statistics between
# that genome and the inferred strains of its species, joined with the
# per-strain metadata. Genomes without a matching-stats file are skipped.
strain_match = {}
for genome_id, d in genome.loc[genome_list].iterrows():
    species = d.species_id
    if species == "TODO":
        continue
    strain_match_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.{genome_id}.geno_matching_stats.tsv"
    strain_meta_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta-s95-d100-a0-pos100-std25.tsv"
    # Only the matching-stats file is checked; the strain_meta file is read
    # unconditionally once the former exists.
    if os.path.exists(strain_match_path):
        # The strain_meta table's index column is called "genome_id" on disk;
        # relabel it "strain" so it can be joined on the strain level below.
        strain_meta = pd.read_table(
            strain_meta_path, index_col="genome_id"
        ).rename_axis("strain")
        strain_match[genome_id] = (
            pd.read_table(strain_match_path, index_col=["genome_id", "strain"])
            .assign(
                # (1 - dissimilarity) * positions compared; presumably the
                # number of compared positions that agree -- confirm upstream
                # definition of genotype_dissimilarity.
                genotype_matching_positions=lambda x: (1 - x.genotype_dissimilarity)
                * x.num_geno_positions_compared,
                # Dissimilarity plus a pseudocount of 1/positions, so that
                # among equal dissimilarities the comparison backed by more
                # positions gets the smaller value.
                genotype_dissimilarity_pc=lambda x: x.genotype_dissimilarity
                + (1 / x.num_geno_positions_compared),
                # Keep the source path for provenance.
                strain_match_path=strain_match_path,
            )
            .join(strain_meta, on="strain")
        )
# Stack the per-genome tables into one frame indexed by (genome_id, strain)
# and attach a QC code to every row.
strain_match = (
    pd.concat(strain_match.values()).reset_index().set_index(["genome_id", "strain"])
).assign(qc_code=lambda x: x.apply(assign_qc_code, axis=1))
strain_match

In [None]:
# NOTE: Strains are matched only from among those with "strain-pure-samples" in
# the xjin set (strain_depth_max > 0). For each genome we then pick the strain
# with the lowest genotype_dissimilarity_pc, breaking ties by the largest
# strain_depth_sum.
# TODO: Consider also allowing strains whose strain-pure samples are outside
# the xjin set, provided they are detected at appreciable relative abundance
# inside the xjin samples?
genome_to_spgc_strain = (
    strain_match[lambda x: x.strain_depth_max > 0]
    .reset_index()
    .sort_values(
        ["genotype_dissimilarity_pc", "strain_depth_sum"], ascending=(True, False)
    )
    # After sorting, the first row of each genome group is its best match.
    .groupby("genome_id")
    .head(1)
    .set_index("genome_id")
    # Reinstate every benchmark genome; genomes with no candidate strain become
    # all-NaN rows, which are filled with "no match" defaults below. Columns not
    # listed in the fill mapping (e.g. genotype_dissimilarity_pc) stay NaN.
    .reindex(genome_list)
    .fillna(
        {
            "strain": -1,
            "genotype_dissimilarity": 1.0,
            "strain_depth_sum": 0,
            "strain_depth_max": 0,
            "passes_total_depth": False,
            "passes_species_gene_frac": False,
            "passes_gene_count": False,
            "passes_log_selected_gene_depth_ratio_std": False,
            "passes_geno_positions": False,
            "passes_filter": False,
            "qc_code": -1,
        }
    )
    # NaNs introduced by the reindex make these columns float; cast them back.
    .astype({"strain": int, "qc_code": int})
)

# Distribution of the pseudocounted dissimilarity of the chosen matches, with
# an explicit zero edge ahead of the log-spaced bins.
plt.hist(
    genome_to_spgc_strain.genotype_dissimilarity_pc, bins=[0] + list(np.logspace(-5, 0))
)
plt.xscale("symlog", linthresh=1e-5, linscale=0.1)
genome_to_spgc_strain.sort_values("genotype_matching_positions", ascending=False)

In [None]:
plt.scatter(
    "strain_depth_sum",
    "genotype_dissimilarity_pc",
    data=genome_to_spgc_strain,
    alpha=0.4,
)

plt.yscale("symlog", linthresh=1e-4)
plt.xscale("symlog", linthresh=1e-2, linscale=0.1)

In [None]:
# data/group/xjin/r.proc.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.SPECIES_DEPTH_BENCHMARK_GRID.flag
# data/group/xjin/r.proc.gene99_v15-v22-agg75.ACCURACY_BENCHMARK_GRID.flag
# data/group/xjin/r.proc.gene99_v15-v22-agg75.spgc-fit.STRAIN_META_BENCHMARK_GRID.flag


# Load, for every benchmark genome, the reconstruction-accuracy tables of each
# tool / gene-unit combination, plus per-species depth summaries and the
# StrainPGC strain QC metadata.
#
# Results (module-level names):
#   benchmark  -- one row per (genome_id, tool, unit, match). For the StrainPGC
#                 tools, "match" is the genotype-matched strain and
#                 "top"/"second" are the strains ranked by F1. For PanPhlAn and
#                 StrainPanDA there is no genotype match, so "match" is the
#                 same row as "top" (best F1), followed by "second"/"third".
#   missing    -- expected accuracy paths that do not exist on disk.
#   depth_meta -- per-species max and sum of the species depth.
#   spgc_qc    -- strain_meta row of each genome's matched strain, with a
#                 qc_code column added.
benchmark = {}
missing = []
unmatched = []  # Not filled in this cell; kept in case other cells refer to it.
depth_meta = {}
spgc_qc = {}
for genome_id, d in genome.iterrows():
    species = d.species_id
    if species == "TODO":
        continue

    # Resolve the genotype-matched strain once per genome, before any file
    # checks. It is needed both in the StrainPGC loop and for the QC lookup at
    # the bottom; assigning it only after an accuracy file was found would let
    # the QC lookup reuse the value of a previous genome (or fail with a
    # NameError on the first genome) whenever no such file exists.
    matched_strain = genome_to_spgc_strain.strain[genome_id]

    depth_path = f"data/group/xjin/species/sp-{species}/r.proc.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv"
    _depth = pd.read_table(
        depth_path, names=["sample_id", "depth"], index_col="sample_id"
    ).depth
    # Keyed by species, so genomes of the same species overwrite each other
    # with identical values.
    depth_meta[species] = pd.Series(
        dict(species_depth_max=_depth.max(), species_depth_sum=_depth.sum())
    )

    # SPGC
    for unit in ["uhggtop", "eggnog", "cog"]:
        for tool in [
            "spgc-fit",
            # "spgc2-fit",
            # "nnmatched-m50",
            # "nnmatched-m10",
            # "nnmatched-m1",
            # "nnmatched-m0",
            "spgc-depth200",
        ]:
            accuracy_path = f"data/group/xjin/species/sp-{species}/r.proc.gene99_v15-v22-agg75.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
            if not os.path.exists(accuracy_path):
                missing.append(accuracy_path)
                continue
            # Rows are strains, sorted best F1 first.
            data = (
                pd.read_table(accuracy_path, index_col="strain")
                .assign(
                    species=species,
                    accuracy_path=accuracy_path,
                    strain=lambda x: x.index,
                )
                .sort_values("f1", ascending=False)
            )

            if matched_strain in data.index:
                benchmark[(genome_id, tool, unit, "match")] = data.loc[matched_strain]
            if data.shape[0] >= 1:
                benchmark[(genome_id, tool, unit, "top")] = data.iloc[0]
            if data.shape[0] >= 2:
                benchmark[(genome_id, tool, unit, "second")] = data.iloc[1]
        # Tools without a genotype-based match: the best-F1 row doubles as the
        # "match" entry.
        for tool in [
            "panphlan",
            # "spanda-s2",
            "spanda-s3",
            "spanda-s4",
            "spanda-s5",
            "spanda-s6",
        ]:
            accuracy_path = f"data/group/xjin/species/sp-{species}/r.proc.gene99_v15-v22-agg75.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
            if not os.path.exists(accuracy_path):
                missing.append(accuracy_path)
                continue
            data = (
                pd.read_table(accuracy_path, index_col="strain")
                .assign(
                    species=species,
                    accuracy_path=accuracy_path,
                    strain=lambda x: x.index,
                )
                .sort_values("f1", ascending=False)
            )
            if data.shape[0] >= 1:
                benchmark[(genome_id, tool, unit, "top")] = data.iloc[0]
                benchmark[(genome_id, tool, unit, "match")] = data.iloc[0]
            if data.shape[0] >= 2:
                benchmark[(genome_id, tool, unit, "second")] = data.iloc[1]
            if data.shape[0] >= 3:
                benchmark[(genome_id, tool, unit, "third")] = data.iloc[2]

    # QC metadata of this genome's matched strain, when available.
    spgc_qc_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta-s95-d100-a0-pos100-std25.tsv"
    if not os.path.exists(spgc_qc_path):
        continue
    _qc = pd.read_table(spgc_qc_path, index_col="genome_id")
    if matched_strain in _qc.index:
        spgc_qc[genome_id] = _qc.loc[matched_strain]

# Convert the collected dictionaries into DataFrames.
benchmark = pd.DataFrame(benchmark.values(), index=benchmark.keys()).rename_axis(
    ["genome_id", "tool", "unit", "match"]
)
depth_meta = pd.DataFrame(depth_meta.values(), index=depth_meta.keys()).rename_axis(
    "species"
)
spgc_qc = pd.DataFrame(spgc_qc.values(), index=spgc_qc.keys()).assign(
    qc_code=lambda x: x.apply(assign_qc_code, axis=1)
)

benchmark.sort_values(["species", "tool"])

In [None]:
# How many additional species had _no_ depth whatsoever?

genome.join(depth_meta, on="species_id")[lambda x: x.species_depth_sum == 0]

In [None]:
# And how many had any depth?

genome.join(depth_meta, on="species_id")[lambda x: x.species_depth_max.fillna(0) > 0]

In [None]:
genome_filt_list = list(
    genome.join(depth_meta, on="species_id")[
        lambda x: (x.species_id != "TODO") & (x.species_depth_sum > 0)
    ].index
)
species_filt_list = list(genome.loc[genome_filt_list].species_id.unique())

len(genome_filt_list), len(species_filt_list)

In [None]:
# How many actual SPGC fits are there?

benchmark.xs(
    ["spgc-fit", "eggnog", "match"], level=("tool", "unit", "match")
).f1.sort_values()

In [None]:
# Which theoretically estimate-able genomes have no StrainPGC results?

set(genome_filt_list) - set(
    benchmark.xs(["spgc-fit", "eggnog", "match"], level=("tool", "unit", "match")).index
)

In [None]:
# What's wrong with this strain?
genome.loc["Collinsella-aerofaciens-ATCC-25986_MAF-2"]

In [None]:
# Which zero-depth species have StrainPanDA results?

set(
    genome.join(depth_meta, on="species_id")[
        lambda x: x.species_depth_max.fillna(0) == 0
    ].index
) & set(
    benchmark.xs(
        ["spanda-s5", "eggnog", "match"], level=("tool", "unit", "match")
    ).index
)

In [None]:
# Which zero-depth species have PanPhlAn results?

set(
    genome.join(depth_meta, on="species_id")[
        lambda x: x.species_depth_max.fillna(0) == 0
    ].index
) & set(
    benchmark.xs(["panphlan", "eggnog", "match"], level=("tool", "unit", "match")).index
)

In [None]:
# How did the other tools perform on these?
benchmark.xs(["eggnog", "match"], level=("unit", "match")).loc[
    ["Roseburia-inulinivorans-DSM-16841_MAF-2"]
][["precision", "recall", "f1"]].unstack("tool")

In [None]:
bins = np.linspace(0, 1, num=50)
d = benchmark.xs(("eggnog", "top"), level=("unit", "match")).f1.unstack().fillna(0)
for tool in [
    "spgc-fit",
    # "nnmatched-m0", "nnmatched-m50",
    "spgc-depth200",
    # "spanda-s4",
]:
    plt.hist(d[tool], bins=bins, alpha=0.5, label=tool)
# plt.hist(d["nnmatched-m0"], bins=bins, alpha=0.5)
# plt.hist(d["nnmatched-m50"], bins=bins, alpha=0.5)
# plt.hist(d["spgc-depth200"], bins=bins, alpha=0.5)
# plt.hist(d["panphlan"], bins=bins, alpha=0.5)
# plt.hist(d["spanda-s2"], bins=bins, alpha=0.5)
# plt.hist(d["spanda-s3"], bins=bins, alpha=0.5)
# plt.hist(d["spanda-s4"], bins=bins, alpha=0.5)
plt.legend()
None

In [None]:
bins = np.linspace(0, 1, num=50)
d = benchmark.xs(("eggnog", "top"), level=("unit", "match")).f1.unstack().fillna(0)
for tool in [
    # "spgc-fit",
    # "nnmatched-m0", "nnmatched-m50",
    # "spgc-depth200",
    # "spanda-s2",
    "spanda-s3",
    "spanda-s4",
    "spanda-s5",
]:
    plt.hist(d[tool], bins=bins, alpha=0.5, label=tool)
# plt.hist(d["nnmatched-m0"], bins=bins, alpha=0.5)
# plt.hist(d["nnmatched-m50"], bins=bins, alpha=0.5)
# plt.hist(d["spgc-depth200"], bins=bins, alpha=0.5)
# plt.hist(d["panphlan"], bins=bins, alpha=0.5)
# plt.hist(d["spanda-s2"], bins=bins, alpha=0.5)
# plt.hist(d["spanda-s3"], bins=bins, alpha=0.5)
# plt.hist(d["spanda-s4"], bins=bins, alpha=0.5)
plt.legend()
None

In [None]:
d0 = genome_to_spgc_strain.join(
    benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
).fillna(
    0
)  # .loc[genome_list]

plt.scatter(
    "spanda-s5",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
d0 = genome_to_spgc_strain.join(
    benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
).fillna(
    0
)  # .loc[genome_list]

plt.scatter(
    "spanda-s4",
    "spanda-s6",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
benchmark.xs(
    ("eggnog", "match", "spanda-s5"), level=("unit", "match", "tool")
).sort_values("f1", ascending=False)

In [None]:
# Per genome, compare the F1 of the genotype-matched StrainPGC strain
# ("match") with the F1 of the best-scoring strain ("top"); genomes lacking
# either value are plotted at 0.
d0 = genome_to_spgc_strain.join(
    benchmark.xs(("spgc-fit", "eggnog"), level=("tool", "unit")).f1.unstack()
).fillna(0)


# No color data (`c=`) is passed, so the points are not colormapped. A `norm`
# without `c` is ignored by matplotlib (with a warning), so it is left out.
plt.scatter(
    "match",
    "top",
    data=d0,
    alpha=0.4,
    # To color by depth, restore both of:
    # c="strain_depth_max",
    # norm=mpl.colors.SymLogNorm(linthresh=0.2),
)
# plt.colorbar()
# # plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
# plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
plt.scatter(
    "precision",
    "recall",
    data=benchmark.xs(("spgc-fit", "eggnog", "match"), level=("tool", "unit", "match")),
)
plt.plot([0, 1], [0, 1])

In [None]:
# Which genomes are failing and why?
(
    genome_to_spgc_strain.join(
        benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
    )
    .fillna(0)[lambda x: x["spgc-fit"] == 0]
    .join(genome)
    # .join(depth_meta, on="species_id")
    .join(
        genome.species_id.value_counts().rename("num_strains_in_species"),
        on="species_id",
    )
)

In [None]:
# What's the point of benchmarking strains that have no or very low depth?
# that just even more heavily favors the tools that I'm taking "best hit" for.
d0 = (
    genome_to_spgc_strain.join(
        benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
    )
    .fillna(0)
    .join(genome)
    # .join(depth_meta, on="species_id")
    [lambda x: x.species_depth_sum >= 0.5]
)
print(len(d0))

plt.scatter(
    "panphlan",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
d0 = (
    genome_to_spgc_strain.join(
        benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
    )
    .fillna(0)
    .join(genome)
    # .join(depth_meta, on="species_id")
    [lambda x: x.species_depth_sum >= 0.5]
)

plt.scatter(
    "panphlan",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
plt.ylim(0.55, 1.05)
plt.xlim(0.55, 1.05)

In [None]:
d0 = (
    benchmark.xs(("eggnog", "match"), level=("unit", "match"))
    .f1.unstack()
    .fillna(0)
    .join(
        genome_to_spgc_strain[
            [
                "genotype_dissimilarity",
                "strain_depth_sum",
                "strain_depth_max",
            ]
        ]
    )
    .fillna({"genotype_dissimilarity": 1.0})
    .fillna(0)
)

plt.scatter(
    "genotype_dissimilarity",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    c="strain_depth_max",
    norm=mpl.colors.SymLogNorm(linthresh=0.1),
)
plt.colorbar()
# plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
# plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
plt.xscale("symlog", linthresh=1e-5, linscale=0.1)

In [None]:
d0 = (
    benchmark.xs(("eggnog", "match"), level=("unit", "match"))
    .f1.unstack()
    .fillna(0)
    .join(
        genome_to_spgc_strain[
            [
                "genotype_dissimilarity",
                "strain_depth_sum",
                "strain_depth_max",
            ]
        ]
    )
    .fillna({"genotype_dissimilarity": 1.0})
    .fillna(0)
)

plt.scatter(
    "strain_depth_sum",
    "genotype_dissimilarity",
    data=d0,
    alpha=0.4,
    c="spgc-fit",
    vmin=0,
    vmax=1,
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
# plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
plt.yscale("symlog", linthresh=1e-4)
plt.xscale("symlog", linthresh=1e-2, linscale=0.1)

In [None]:
d0 = (
    benchmark.xs(("eggnog", "match"), level=("unit", "match"))
    .f1.unstack()
    .fillna(0)
    .join(
        genome_to_spgc_strain[
            [
                "genotype_dissimilarity",
                "strain_depth_sum",
                "strain_depth_max",
            ]
        ]
    )
    .fillna({"genotype_dissimilarity": 1.0})
    .fillna(0)
)

plt.scatter(
    "strain_depth_sum",
    "genotype_dissimilarity",
    data=d0,
    alpha=0.4,
    c="spgc-fit",
    vmin=0,
    vmax=1,
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
# plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
plt.yscale("symlog", linthresh=1e-4)
plt.xscale("symlog", linthresh=1e-2, linscale=0.1)

In [None]:
multi_genome_species = idxwhere(
    strain_match.join(genome.species_id)
    .reset_index()[["species_id", "genome_id"]]
    .drop_duplicates()
    .species_id.value_counts()
    > 1
)
multi_genome_species

In [None]:
d = (
    genome.join(genome_to_spgc_strain)[
        lambda x: (x.species_id != "TODO")
        & (x.species_depth_sum > 0)
        & (x.species_depth_max > 0)
        # & ~x.species_id.isin(multi_genome_species)
    ]
    .join(
        benchmark.xs(("spgc-fit", "eggnog", "match"), level=("tool", "unit", "match")),
        rsuffix="_",
    )
    .fillna({"precision": 0, "recall": 0, "f1": 0, "jaccard": 0})
)
d

# Statistics

## How many strains/species/phyla are analyzed in the benchmark?

In [None]:
# Full list

len(genome_list), len(species_list), len(
    species_taxonomy.loc[species_list].p__.unique()
)

In [None]:
# Filtered list
len(genome_filt_list), len(species_filt_list), len(
    species_taxonomy.loc[species_filt_list].p__.unique()
)

In [None]:
# TODO: Which species were excluded?
#  - For too few GT-Pro positions?
#  - For too little depth?
#  - For other reasons???
#  - What was supposed to be included but basically failed, and why?

## StrainPGC Performance

In [None]:
tool = "spgc-fit"
unit = "eggnog"
match = "match"

d = (
    genome_to_spgc_strain.loc[genome_filt_list]
    .join(benchmark.xs((unit, match), level=("unit", "match")), rsuffix="_")[
        ["precision", "recall", "f1"]
    ]
    .unstack("tool")
    .xs(tool, level="tool", axis="columns")
    .fillna(0)
    .join(genome_to_spgc_strain.species_gene_frac)
)

# plt.scatter('precision', "recall", data=d, c='species_gene_frac')
# plt.colorbar()
# plt.plot([0, 1], [0, 1])
print(d[["precision", "recall", "f1"]].quantile([0.25, 0.5, 0.75]))
# print(d[d.species_gene_frac > 0.9][["precision", "recall", "f1", "jaccard"]].quantile([0.25, 0.5, 0.75]))

## How does it compare to other tools?

In [None]:
d = (
    genome_to_spgc_strain.loc[genome_filt_list]
    .join(benchmark.xs((unit, match), level=("unit", "match")), rsuffix="_")
    .f1.unstack("tool")
    .fillna(0)
)


print((-(d.subtract(d["spgc-fit"], axis=0))).quantile([0.25, 0.5, 0.75]))

In [None]:
unit = "eggnog"
match = "match"

qc_code_palette = {
    0: "grey",
    1: "lightgreen",
    2: "lightblue",
    3: "lightsalmon",
    -1: "white",
}

d0 = benchmark.xs((unit, match), level=("unit", "match"))

_tool_comparison_order = [
    "panphlan",
    "spanda-s6",
    # "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2 * len(_tool_comparison_order), 2 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y_tool = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d1 = (
        d0[score]
        .unstack(fill_value=0)
        .join(genome_to_spgc_strain)
        .assign(
            c=lambda x: x.qc_code.astype(int).map(
                qc_code_palette,
            )
        )
        .sort_values("qc_code")
    )
    print(
        score, "SPGC IQR:", (d1[y_tool]).quantile([0.25, 0.5, 0.75]).round(3).tolist()
    )
    for x_tool, ax in zip(_tool_comparison_order, ax_row):
        print(
            "compared to: {}, {}, {:.1g}".format(
                x_tool,
                (d1[y_tool] - d1[x_tool])
                .quantile([0.05, 0.25, 0.5, 0.75, 0.95])
                .round(3)
                .tolist(),
                sp.stats.wilcoxon(d1[x_tool], d1[y_tool]).pvalue,
            )
        )
        ax.scatter(
            x_tool,
            y_tool,
            data=d1,
            c="c",
            lw=0.5,
            edgecolor="k",
            s=15,
            label="__nolegend__",
        )
        ax.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
    print()

# Label each row's left-hand panel with the score it shows. A dedicated loop
# variable is used so that the `_score_order` list is not rebound to a string.
for score, ax in zip(_score_order, axs[:, 0]):
    ax.set_ylabel(score)
    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim(-0.05, 1.05)

for x_tool, ax in zip(_tool_comparison_order, axs[-1, :]):
    ax.set_xlabel(x_tool)
    ax.set_xticks([0, 0.5, 1])
    ax.set_xlim(-0.05, 1.05)

fig, ax = plt.subplots(figsize=(2, 1.25))
for i in range(-1, 4):
    ax.scatter(
        [],
        [],
        label=qc_code_meaning[i],
        color=qc_code_palette[i],
        lw=0.5,
        edgecolor="k",
        s=15,
    )
ax.legend(markerscale=2, loc="lower right")
lib.plot.hide_axes_and_spines(ax=ax)

In [None]:
unit = "eggnog"
match = "match"

qc_code_palette = {
    0: "grey",
    1: "lightgreen",
    2: "lightblue",
    3: "lightsalmon",
    -1: "white",
}

d0 = benchmark.xs((unit, match), level=("unit", "match"))

_tool_comparison_order = [
    "panphlan",
    "spanda-s6",
    # "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2 * len(_tool_comparison_order), 2 * len(_score_order)),
    sharex=True,
    sharey=True,
    gridspec_kw=dict(hspace=0.1, wspace=0.1),
)
y_tool = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d1 = (
        d0[score]
        .unstack(fill_value=0)
        .join(genome_to_spgc_strain)
        .assign(
            c=lambda x: x.qc_code.astype(int).map(
                qc_code_palette,
            )
        )
        .sort_values("qc_code")
    )
    print(
        score, "SPGC IQR:", (d1[y_tool]).quantile([0.25, 0.5, 0.75]).round(3).tolist()
    )
    for x_tool, ax in zip(_tool_comparison_order, ax_row):
        print(
            "compared to: {}, {}, {:.1g}".format(
                x_tool,
                (d1[y_tool] - d1[x_tool])
                .quantile([0.05, 0.25, 0.5, 0.75, 0.95])
                .round(3)
                .tolist(),
                sp.stats.wilcoxon(d1[x_tool], d1[y_tool]).pvalue,
            )
        )
        ax.scatter(
            x_tool,
            y_tool,
            data=d1,
            # c="c",
            color='grey',
            lw=0.25,
            edgecolor="k",
            s=15,
            label="__nolegend__",
        )
        ax.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
    print()

# Label each row's left-hand panel with the score it shows. A dedicated loop
# variable is used so that the `_score_order` list is not rebound to a string.
for score, ax in zip(_score_order, axs[:, 0]):
    ax.set_ylabel(score)
    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim(-0.05, 1.05)

for x_tool, ax in zip(_tool_comparison_order, axs[-1, :]):
    ax.set_xlabel(x_tool)
    ax.set_xticks([0, 0.5, 1])
    ax.set_xlim(-0.05, 1.05)

# fig, ax = plt.subplots(figsize=(2, 1.25))
# for i in range(-1, 4):
#     ax.scatter(
#         [],
#         [],
#         label=qc_code_meaning[i],
#         color=qc_code_palette[i],
#         lw=0.5,
#         edgecolor="k",
#         s=15,
#     )
# ax.legend(markerscale=2, loc="lower right")
# lib.plot.hide_axes_and_spines(ax=ax)

In [None]:
unit = "eggnog"
match = "match"

qc_code_palette = {
    0: "grey",
    1: "lightgreen",
    2: "lightblue",
    3: "lightsalmon",
    -1: "grey",
}

d0 = benchmark.xs((unit, match), level=("unit", "match"))

_tool_comparison_order = [
    "panphlan",
    "spanda-s6",
    # "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2 * len(_tool_comparison_order), 2 * len(_score_order)),
    sharex=True,
    sharey=True,
    gridspec_kw=dict(hspace=0.1, wspace=0.1),
)
y_tool = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d1 = (
        d0[score]
        .unstack(fill_value=0)
        .join(genome_to_spgc_strain)
        .assign(
            c=lambda x: x.qc_code.astype(int).map(
                qc_code_palette,
            )
        )
        .sort_values("qc_code")
    )
    print(
        score, "SPGC IQR:", (d1[y_tool]).quantile([0.25, 0.5, 0.75]).round(3).tolist()
    )
    for x_tool, ax in zip(_tool_comparison_order, ax_row):
        print(
            "compared to: {}, {}, {:.1g}".format(
                x_tool,
                (d1[y_tool] - d1[x_tool])
                .quantile([0.05, 0.25, 0.5, 0.75, 0.95])
                .round(3)
                .tolist(),
                sp.stats.wilcoxon(d1[x_tool], d1[y_tool]).pvalue,
            )
        )
        ax.scatter(
            x_tool,
            y_tool,
            data=d1,
            # c="c",
            lw=.8,
            edgecolors="c",
            facecolor='none',
            s=30,
            alpha=0.75,
            label="__nolegend__",
        )
        ax.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
    print()

# Label each row's left-hand panel with the score it shows. A dedicated loop
# variable is used so that the `_score_order` list is not rebound to a string.
for score, ax in zip(_score_order, axs[:, 0]):
    ax.set_ylabel(score)
    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim(-0.05, 1.05)

for x_tool, ax in zip(_tool_comparison_order, axs[-1, :]):
    ax.set_xlabel(x_tool)
    ax.set_xticks([0, 0.5, 1])
    ax.set_xlim(-0.05, 1.05)

fig, ax = plt.subplots(figsize=(2, 1.25))
for i in range(-1, 4):
    ax.scatter(
        [],
        [],
        label=qc_code_meaning[i],
        # color=qc_code_palette[i],
        lw=1,
        edgecolor=qc_code_palette[i],
        facecolor='none',
        s=30,
    )
ax.legend(markerscale=2, loc="lower right")
lib.plot.hide_axes_and_spines(ax=ax)

In [None]:
unit = "eggnog"
match = "match"

# qc_code_palette = {
#     0: "grey",
#     1: "lightgreen",
#     2: "lightblue",
#     3: "lightsalmon",
#     -1: "grey",
# }

cmap='Greys'
norm=mpl.colors.LogNorm(vmin=0.2, vmax=50)

d0 = benchmark.xs((unit, match), level=("unit", "match"))

_tool_comparison_order = [
    "panphlan",
    "spanda-s6",
    # "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2 * len(_tool_comparison_order), 2 * len(_score_order)),
    sharex=True,
    sharey=True,
    gridspec_kw=dict(hspace=0.1, wspace=0.1),
    squeeze=False,
)
y_tool = "spgc-fit"
# One row of panels per score; each panel is a 2-D histogram of genomes binned
# by the comparison tool's score (x) against the `y_tool` score (y).
for score, ax_row in zip(_score_order, axs):
    # The per-point color column used by the scatter versions of this figure
    # is not needed for a 2-D histogram, so it is not built here. This also
    # removes the reliance on a `qc_code_palette` defined by a different cell.
    d1 = (
        d0[score]
        .unstack(fill_value=0)
        .join(genome_to_spgc_strain)
        .sort_values("qc_code")
    )
    print(
        score, "SPGC IQR:", (d1[y_tool]).quantile([0.25, 0.5, 0.75]).round(3).tolist()
    )
    for x_tool, ax in zip(_tool_comparison_order, ax_row):
        # Quantiles of the paired score difference (y_tool minus x_tool) and
        # the p-value of a Wilcoxon signed-rank test on the pairs.
        print(
            "compared to: {}, {}, {:.1g}".format(
                x_tool,
                (d1[y_tool] - d1[x_tool])
                .quantile([0.05, 0.25, 0.5, 0.75, 0.95])
                .round(3)
                .tolist(),
                sp.stats.wilcoxon(d1[x_tool], d1[y_tool]).pvalue,
            )
        )
        art = ax.hist2d(
            x_tool,
            y_tool,
            data=d1,
            cmap=cmap,
            norm=norm,
            bins=np.linspace(0, 1, num=21),
        )
        ax.plot([0, 1], [0, 1], lw=0.5, linestyle="--", color="k")
    print()

# Label each row's left-hand panel with the score it shows. A dedicated loop
# variable is used so that the `_score_order` list is not rebound to a string.
for score, ax in zip(_score_order, axs[:, 0]):
    ax.set_ylabel(score)
    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim(-0.05, 1.05)

for x_tool, ax in zip(_tool_comparison_order, axs[-1, :]):
    ax.set_xlabel(x_tool)
    ax.set_xticks([0, 0.5, 1])
    ax.set_xlim(-0.05, 1.05)

fig, ax = plt.subplots(figsize=(0.1, 1.25))
fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax)
ax.set_yscale('log', subs=[])
ax.set_yticks([1, 10, 50])
# for i in range(-1, 4):
#     ax.scatter(
#         [],
#         [],
#         label=qc_code_meaning[i],
#         # color=qc_code_palette[i],
#         lw=1,
#         edgecolor=qc_code_palette[i],
#         facecolor='none',
#         s=30,
#     )
# ax.legend(markerscale=2, loc="lower right")
# lib.plot.hide_axes_and_spines(ax=ax)

In [None]:
# Scatter comparison: for every score (rows) and every alternative tool
# (columns), plot each genome's score for that tool (x) against its score for
# the reference tool `y_tool` (y), with marker faces colored by `qc_code`.
unit = "eggnog"
match = "match"

# Marker face color for each QC code; -1 stands in for a missing code.
qc_code_palette = {
    0: "grey",
    1: "lightgreen",
    2: "lightsalmon",
    3: "violet",
    -1: "white",
}

d0 = benchmark.xs((unit, match), level=("unit", "match"))

_tool_comparison_order = [
    "panphlan",
    "spanda-s4",
    "spanda-s5",
    "spgc-fit",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(3 * len(_tool_comparison_order), 3 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y_tool = "spanda-s6"
for score, ax_row in zip(_score_order, axs):
    # Pivot the innermost index level into columns (missing entries become 0),
    # attach the strain metadata, and derive the marker color column `c`.
    d1 = (
        d0[score]
        .unstack(fill_value=0)
        .join(genome_to_spgc_strain)
        .assign(
            c=lambda x: x.qc_code.fillna(-1)
            .astype(int)
            .map(
                qc_code_palette,
            )
        )
        .sort_values("qc_code")
    )
    # Name the summary after `y_tool` itself so the printout always matches
    # the tool whose quartiles are reported.
    print(
        score,
        f"{y_tool} IQR:",
        d1[y_tool].quantile([0.25, 0.5, 0.75]).round(3).tolist(),
    )
    for x_tool, ax in zip(_tool_comparison_order, ax_row):
        # Quantiles of the paired difference (y_tool - x_tool) and the
        # Wilcoxon signed-rank p-value for the same pairing.
        print(
            "compared to: {}, {}, {:.1g}".format(
                x_tool,
                (d1[y_tool] - d1[x_tool])
                .quantile([0.05, 0.25, 0.5, 0.75, 0.95])
                .round(3)
                .tolist(),
                sp.stats.wilcoxon(d1[x_tool], d1[y_tool]).pvalue,
            )
        )
        # With `data=d1`, the strings x_tool / y_tool / "c" are looked up as
        # columns of d1.
        ax.scatter(
            x_tool,
            y_tool,
            data=d1,
            c="c",
            lw=0.5,
            edgecolor="k",
            s=15,
            label="__nolegend__",
        )
        # Identity line: points above it score higher with `y_tool`.
        ax.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
    print()

# Label the left-most column with the score shown in each row.  The loop
# variable is called `score` so that iterating does not overwrite the
# `_score_order` list itself.
for score, ax in zip(_score_order, axs[:, 0]):
    ax.set_ylabel(score)
    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim(-0.05, 1.05)

# Label the bottom row with the tool plotted on each column's x-axis.
for x_tool, ax in zip(_tool_comparison_order, axs[-1, :]):
    ax.set_xlabel(x_tool)
    ax.set_xticks([0, 0.5, 1])
    ax.set_xlim(-0.05, 1.05)

# Stand-alone legend figure: one empty scatter per QC code (-1 through 3),
# drawn only so that its marker appears in the legend.
fig, ax = plt.subplots(figsize=(2, 1.25))
_legend_marker_style = dict(lw=0.5, edgecolor="k", s=15)
for i in (-1, 0, 1, 2, 3):
    ax.scatter(
        [],
        [],
        label=qc_code_meaning[i],
        color=qc_code_palette[i],
        **_legend_marker_style,
    )
ax.legend(loc="lower right", markerscale=2)
lib.plot.hide_axes_and_spines(ax=ax)

In [None]:
# FIXME: Load data as in the main-text figure.

# Benchmark slice and quantities compared in this figure.
unit, match = "eggnog", "match"
xvar = "strain_depth_max"
_score_list = ["precision", "recall", "f1"]
_tool_list = [
    "spgc-fit",
    "panphlan",
    "spanda-s6",
    "spgc-depth200",
    # "nnmatched-m50"
]

# One fixed color per tool, in the order of `_tool_list`.
# (Alternative: lib.plot.construct_ordered_palette(_tool_list))
_tool_palette = lib.plot.construct_ordered_palette_from_list(
    _tool_list,
    colors=["tab:blue", "tab:green", "tab:orange", "tab:purple"],
)

# Scores for every (genome, tool) pair; pairs absent from the benchmark get 0.
# Genome, depth and strain metadata are joined on afterwards.
_benchmark_scores = benchmark.xs((unit, match), level=("unit", "match"))[_score_list]
d0 = (
    _benchmark_scores.reindex(product(genome_filt_list, _tool_list))
    .fillna(0)
    .join(genome)
    # Optional filter: [lambda x: ~x.species_id.isin(multi_genome_species)]
    .join(depth_meta, on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)

# One panel per score, stacked vertically with shared axes.
_num_panels = len(_score_list)
fig, axs = plt.subplots(
    _num_panels,
    figsize=(5, 3 * _num_panels),
    sharex=True,
    sharey=True,
)

# For each tool, scatter every score against `xvar` and overlay a rolling-mean
# trend line.  Only the quantities that are drawn are computed here.
for _tool in _tool_list:
    d1 = d0.xs(_tool, level="tool")
    for _score, ax in zip(_score_list, axs.flatten()):
        # 10-row rolling means of depth and score, taken along increasing
        # `xvar`; `assign` aligns the results back onto d1 by index.
        d2 = d1.assign(
            xvar_rolling_average=lambda d: d.sort_values(xvar)
            .rolling(window=10)[xvar]
            .mean(),
            score_rolling_average=lambda d: d.sort_values(xvar)
            .rolling(window=10)[_score]
            .mean(),
        )
        ax.scatter(
            xvar,
            _score,
            data=d2,
            label=_tool,
            s=8,
            alpha=0.5,
            facecolor="none",
            color=_tool_palette[_tool],
        )
        print(_tool, _score, sp.stats.spearmanr(d2[xvar], d2[_score]))
        # Trend line, drawn in depth order so the line does not zig-zag.
        ax.plot(
            "xvar_rolling_average",
            "score_rolling_average",
            data=d2.sort_values(xvar),
            label="__nolegend__",
            lw=1,
            linestyle="-",
            color=_tool_palette[_tool],
        )

for _score, ax in zip(_score_list, axs.flatten()):
    ax.set_ylabel(_score)

# `ax` is the last panel here.  The panels were created with sharex/sharey, so
# the scale and limits set below are expected to apply to every panel; the
# legend is attached to the last panel only.
ax.set_xscale("symlog", linthresh=1e-1, linscale=0.1)

# Alternative y-scale that stretches the region near 1 (kept for reference):
# pad = 1e-2
# ax.set_yscale(
#     "function",
#     functions=(lambda x: -np.log(1 + pad - x), lambda y: (1 + pad - np.exp(-y))),
# )
# ax.set_yticks(np.unique([0.0, 0.5, 0.75, 0.9, 0.95, 0.99]))
ax.set_ylim(-0.05, 1.05)

ax.legend(bbox_to_anchor=(1, 1))

# Number of rows plotted for the last tool/score combination.
print(len(d2))

In [None]:
# Read the per-sample depth table of every species in `species_list` and
# combine them into one sample-by-species matrix.  Species without a depth
# file are recorded in `_missing_species` and left out of the matrix.
species_depth = []
_missing_species = []

for species in tqdm(species_list):
    inpath = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.species_depth.tsv"
    if os.path.exists(inpath):
        # The file has no header row: first column is the sample, second the depth.
        data = pd.read_table(inpath, names=["sample", "depth"])
        data["species"] = species
        species_depth.append(data)
    else:
        _missing_species.append(species)

# Long -> wide; (sample, species) pairs without a record get depth 0.
species_depth = pd.concat(species_depth).set_index(["sample", "species"])
species_depth = species_depth["depth"].unstack(fill_value=0)

print(len(_missing_species), "out of", len(species_list), "species are missing.")

In [None]:
# Build a sample-by-strain depth matrix by splitting each species' depth
# across its strains.  For every species column in `species_depth`, a strain
# table is read from the species' `comm.tsv` file, rescaled so each sample's
# row sums to 1, and multiplied by that species' per-sample depth.
# NOTE(review): the comm table values are presumably per-sample strain
# fractions -- confirm against the file format.
strain_depth = []
missing_files = []
for species_id in species_depth.columns:
    path = f"data/group/xjin/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        # Long table indexed by (sample, strain) -> wide sample-by-strain table.
        d = (
            pd.read_table(path, index_col=["sample", "strain"])
            .squeeze()
            .unstack()
            # .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            # .rename({'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap.
        )
    except FileNotFoundError:
        # No strain table: continue with an empty frame, so the whole species
        # depth ends up in the catch-all column added below.
        missing_files.append(path)
        d = pd.DataFrame([])
    # Keep only strains whose column sum over all samples exceeds 0.05.
    _keep_strains = idxwhere(d.sum() > 0.05)
    # Every sample in the strain table must also be present in `species_depth`.
    assert d.index.isin(species_depth.index).all()
    # Align to the full sample list; samples missing from the table get 0.
    d = d.reindex(index=species_depth.index, columns=_keep_strains, fill_value=0)
    # Catch-all column labelled -1 holding whatever is left of each row's total of 1.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    # Negative values (kept strains summing to more than 1) are set to 0 ...
    d[d < 0] = 0
    # ... and each row is rescaled to sum to 1 again.
    d = d.divide(d.sum(1), axis=0)
    # Scale each sample's row by that sample's depth for this species.
    d = d.multiply(species_depth[species_id], axis=0)
    # Column names become "<species_id>_<strain>", so they stay distinct when
    # the per-species tables are concatenated.
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    strain_depth.append(d)
strain_depth = pd.concat(strain_depth, axis=1)
# Relative abundance: each sample's strain depths divided by their row sum.
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)
# Displayed as the cell output: (number of species, number of missing comm files).
len(species_depth.columns), len(missing_files)

In [None]:
# Restrict to samples whose name starts with "xjin_", then keep only strains
# whose total depth across those samples exceeds 0.01.
_is_xjin_sample = strain_depth.index.str.startswith("xjin_")
xjin_strain_depth = strain_depth.loc[_is_xjin_sample]
_xjin_strain_total = xjin_strain_depth.sum()
xjin_strain_depth = xjin_strain_depth[idxwhere(_xjin_strain_total > 0.01)]
# Histogram of log10 total depth per retained strain.
plt.hist(np.log10(xjin_strain_depth.sum()), bins=20)

## Number of benchmark samples

In [None]:
# Number of rows (samples) in the xjin strain depth table.
len(xjin_strain_depth)

In [None]:
# Strains with total depth above 1 are shown individually; the depth of all
# remaining strains is pooled per sample into a single "other" column.
_strain_total_depth = xjin_strain_depth.sum()
xjin_sotu_strain_list = idxwhere(_strain_total_depth > 1)
other_strain_depth = xjin_strain_depth.drop(columns=xjin_sotu_strain_list).sum(axis=1)

d = xjin_strain_depth[xjin_sotu_strain_list].assign(other=other_strain_depth)

# Clustered heatmap with strains as rows and samples as columns.
sns.clustermap(
    d.transpose(),
    metric="cosine",
    norm=mpl.colors.SymLogNorm(linthresh=0.1),
    figsize=(5, 8),
    xticklabels=0,
    yticklabels=0,
)

In [None]:
# Restrict to samples whose name starts with "xjin_" and convert the species
# column labels to strings.
_is_xjin_sample = species_depth.index.str.startswith("xjin_")
xjin_species_depth = species_depth.loc[_is_xjin_sample].rename(columns=str)
# Keep only species whose total depth across those samples exceeds 0.01.
_xjin_species_total = xjin_species_depth.sum()
xjin_species_depth = xjin_species_depth[idxwhere(_xjin_species_total > 0.01)]
# Histogram of log10 total depth per retained species.
plt.hist(np.log10(xjin_species_depth.sum()), bins=20)

## Depth Distribution of Species in Benchmark Samples

In [None]:
# Per-species summary for the species in `species_filt_list`: deepest-sample
# depth, total depth, number of genomes, and phylum (with its plot color).
d0 = xjin_species_depth.reindex(columns=species_filt_list, fill_value=0)
_max_depth_per_species = d0.max()
print(_max_depth_per_species.quantile([0.25, 0.5, 0.75]))

# Count genomes per species; species without any genome get 0.
# TODO: Do we only want to consider these 85 species, or the larger list?
_species_of_genome = genome_to_spgc_strain.assign(species=genome.species_id).species
_num_genomes_per_species = _species_of_genome.value_counts().reindex(
    d0.columns, fill_value=0
)

# Plot encodings keyed by genome count.  Only four (size, marker) pairs are
# listed, so counts above four have no entry in the size/marker mapping.
_num_genomes_order = range(1, _num_genomes_per_species.max() + 1)
_num_genomes_palette = lib.plot.construct_ordered_palette(
    _num_genomes_order, cm="plasma", other="white"
)
_size_and_marker_choices = [(25, "<"), (20, "s"), (100, "p"), (150, "*")]
_num_genomes_size_and_marker_palette = dict(
    zip(_num_genomes_order, _size_and_marker_choices)
)
# _xjin_species_order = d.median().sort_values(ascending=False).index

d1 = pd.DataFrame(
    {
        "_max": d0.max(),
        "_sum": d0.sum(),
        "num_species_genomes": _num_genomes_per_species,
        "p__": d0.columns.map(species_taxonomy.p__),
    }
)
# Depth contributed by every sample other than the deepest one.
d1["depth_in_other_samples"] = d1["_sum"] - d1["_max"]
# Face color of each species' marker, looked up from its phylum.
d1["phylum_c"] = d1["p__"].map(phylum_palette)
# Unused alternative encodings:
#   num_genomes_c = num_species_genomes.map(_num_genomes_palette)
#   num_genomes_s = 50 * np.sqrt(num_species_genomes)

# Total depth (y) against deepest-sample depth (x), one marker per species.
# Marker shape and size encode the genome count; face color encodes phylum.
fig, ax = plt.subplots(figsize=(3, 4.5))

_marker_outline = dict(lw=0.5, edgecolor="black")
for num_species_genomes in _num_genomes_order:
    d2 = d1[d1.num_species_genomes == num_species_genomes]
    markersize, markershape = _num_genomes_size_and_marker_palette[num_species_genomes]
    # Data points: "_max", "_sum" and "phylum_c" are resolved as columns of d2.
    ax.scatter(
        "_max",
        "_sum",
        data=d2,
        s=markersize,
        marker=markershape,
        facecolors="phylum_c",
        alpha=0.8,
        label="__nolegend__",
        **_marker_outline,
    )
    # Empty grey proxy so the legend shows this marker shape and size.
    ax.scatter(
        [],
        [],
        data=d2,
        s=markersize,
        marker=markershape,
        facecolors="grey",
        label=num_species_genomes,
        **_marker_outline,
    )

# Identity line: total depth equals the deepest sample's depth.
_diagonal = [3e-2, 1e5]
ax.plot(_diagonal, _diagonal, lw=1, linestyle="--", color="k")

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(3e-2, 1e3)
ax.set_ylim(3e-2, 1e5)
ax.set_aspect(1)
ax.set_xlabel("Deepest Sample Depth")
ax.set_ylabel("Total Depth")
ax.legend(loc="lower right", title="Num. Strains")

fig.savefig("fig/xjin_benchmark_species_details.svg", bbox_inches="tight")

In [None]:
# Total number of genomes per phylum.
d1.groupby("p__")["num_species_genomes"].sum()

In [None]:
# Per-species summary (depth statistics, genome count, phylum) for the
# species in `species_filt_list`, used by the axes-swapped plot variant.
d0 = xjin_species_depth.reindex(columns=species_filt_list, fill_value=0)
print(d0.max().quantile([0.25, 0.5, 0.75]))

# Number of genomes per species, 0 where a species has none.
# TODO: Do we only want to consider these 85 species, or the larger list?
_genome_species_labels = genome_to_spgc_strain.assign(species=genome.species_id)
_num_genomes_per_species = (
    _genome_species_labels["species"]
    .value_counts()
    .reindex(d0.columns, fill_value=0)
)

_num_genomes_order = range(1, _num_genomes_per_species.max() + 1)
_num_genomes_palette = lib.plot.construct_ordered_palette(
    _num_genomes_order, cm="plasma", other="white"
)
# (marker size, marker shape) for 1, 2, 3 and 4 genomes respectively; larger
# counts have no entry.
_num_genomes_size_and_marker_palette = {
    num_genomes: size_and_marker
    for num_genomes, size_and_marker in zip(
        _num_genomes_order, [(15, "o"), (30, "s"), (70, "p"), (120, "*")]
    )
}
# _xjin_species_order = d.median().sort_values(ascending=False).index

_species_summary = pd.DataFrame(
    dict(
        _max=d0.max(),
        _sum=d0.sum(),
        num_species_genomes=_num_genomes_per_species,
        p__=d0.columns.map(species_taxonomy.p__),
    )
)
d1 = _species_summary.assign(
    # Depth contributed by every sample other than the deepest one.
    depth_in_other_samples=_species_summary["_sum"] - _species_summary["_max"],
    # Marker face color looked up from the phylum.
    phylum_c=_species_summary["p__"].map(phylum_palette),
    # Unused alternatives:
    #   num_genomes_c=lambda x: x.num_species_genomes.map(_num_genomes_palette)
    #   num_genomes_s=lambda x: 50 * np.sqrt(x.num_species_genomes)
)

# Deepest-sample depth (y) against total depth (x), one marker per species.
# Marker shape and size encode the genome count; face color encodes phylum.
fig, ax = plt.subplots(figsize=(3, 4.5))

for num_species_genomes in _num_genomes_order:
    d2 = d1[d1.num_species_genomes == num_species_genomes]
    markersize, markershape = _num_genomes_size_and_marker_palette[num_species_genomes]
    _marker_style = dict(
        s=markersize, marker=markershape, lw=0.5, edgecolor="black", data=d2
    )
    # Data points: "_sum", "_max" and "phylum_c" are resolved as columns of d2.
    ax.scatter(
        "_sum",
        "_max",
        facecolors="phylum_c",
        label="__nolegend__",
        alpha=0.85,
        **_marker_style,
    )
    # Empty grey proxy so the legend shows this marker shape and size.
    ax.scatter([], [], facecolors="grey", label=num_species_genomes, **_marker_style)

# Identity line: the deepest sample accounts for all of the species' depth.
ax.plot([3e-2, 1e5], [3e-2, 1e5], lw=1, linestyle="--", color="k")

# Log axes with `subs=[]` passed to the log scale.
for _set_scale in (ax.set_yscale, ax.set_xscale):
    _set_scale("log", subs=[])
ax.set_ylim(3e-2, 1e3)
ax.set_xlim(3e-2, 1e5)
ax.set_aspect(1)
ax.set_xlabel("Total Depth")
ax.set_ylabel("Max Depth")

leg = ax.legend(
    title="Num. Strains",
    loc="upper left",
    labelspacing=0.1,
    frameon=False,
)
# NOTE(review): relies on a private Legend attribute, presumably to left-align
# the title with the entries -- confirm it still works after matplotlib upgrades.
leg._legend_box.align = "left"

# fig.savefig("fig/xjin_benchmark_species_details.svg", bbox_inches="tight")

## Number of strains per species after filtering

In [None]:
# Number of filtered genomes belonging to each species.
genome.loc[genome_filt_list, "species_id"].value_counts()

## Quality Control Stats

In [None]:
# Benchmark scores of the "spgc-fit" tool for the filtered genomes, joined
# with the per-strain metadata in `genome_to_spgc_strain`.
unit = "eggnog"
match = "match"
tool = "spgc-fit"

d = (
    benchmark.xs((unit, match, tool), level=("unit", "match", "tool"))
    .reindex(genome_filt_list)
    .join(genome_to_spgc_strain, rsuffix="_")
    # Keep only rows with positive depth *before* taking the logarithm, so
    # log10 is never evaluated on zero depths.  The retained rows and their
    # values are the same as when filtering afterwards.
    .loc[lambda x: x.species_depth_sum > 0]
    .assign(log_species_depth_sum=lambda x: np.log10(x.species_depth_sum))
)

# NOTE: There is one species (Collinsella?) where strain deconvolution failed,
# and which therefore has no SPGC strains.
print(d.shape)

# 3x3 grid: one column per QC statistic (x) and one row per score (y).
# Columns share their x-axis and rows share their y-axis.
fig, axs = plt.subplots(
    3, 3, squeeze=False, figsize=(10, 10), sharex="col", sharey="row"
)

_qc_statistics = [
    "log_selected_gene_depth_ratio_std",
    "species_gene_frac",
    "log_species_depth_sum",
]
_scores = ["precision", "recall", "f1"]

# Walk the grid column by column (rows of `axs.T`), top to bottom within each
# column, printing the Spearman correlation for every panel.
for x, _ax_column in zip(_qc_statistics, axs.T):
    for y, ax in zip(_scores, _ax_column):
        print((x, y), sp.stats.spearmanr(d[x], d[y]))
        ax.scatter(x, y, data=d, s=15, alpha=0.5)
        ax.set_xlabel(x)
        ax.set_ylabel(y)

# (d1[lambda x: x[('precision', 'spgc-fit')] < 0.7])

In [None]:
# Scores of the "spgc-fit" tool for the filtered genomes (genomes missing from
# the benchmark get 0), joined with the per-strain metadata.
unit = "eggnog"
match = "match"
tool = "spgc-fit"

d = (
    benchmark.xs((unit, match, tool), level=("unit", "match", "tool"))[
        ["precision", "recall", "f1"]
    ]
    .reindex(genome_filt_list, fill_value=0)
    .join(genome_to_spgc_strain, rsuffix="_")
    # Keep only rows with positive depth *before* taking the logarithm, so
    # log10 is never evaluated on zero depths.  The retained rows and their
    # values are the same as when filtering afterwards.
    .loc[lambda x: x.species_depth_sum > 0]
    .assign(log_species_depth_sum=lambda x: np.log10(x.species_depth_sum))
)
# Genomes whose `passes_filter` flag is False.
low_quality_strains = idxwhere(~d.passes_filter)

# Displayed as the cell output: number of low-quality strains, then the F1
# quartiles of the low-quality group and of the remaining genomes.
(
    len(low_quality_strains),
    d.loc[low_quality_strains].f1.quantile([0.25, 0.5, 0.75]),
    d.drop(low_quality_strains).f1.quantile([0.25, 0.5, 0.75]),
)

In [None]:
# Mann-Whitney U test comparing F1 of low-quality strains with all others.
_f1_low_quality = d.loc[low_quality_strains].f1
_f1_remaining = d.drop(low_quality_strains).f1
sp.stats.mannwhitneyu(_f1_low_quality, _f1_remaining)

In [None]:
# Overlaid F1 histograms on a common grid of 50 evenly spaced edges in [0, 1]:
# low-quality strains first, then all other strains drawn semi-transparent.
bins = np.linspace(0, 1, num=50)
_f1 = d.f1

plt.hist(_f1.loc[low_quality_strains], bins=bins)
plt.hist(_f1.drop(low_quality_strains), bins=bins, alpha=0.5)

In [None]:
# Precision quartiles: (low-quality strains, all other strains).
_quartiles = [0.25, 0.5, 0.75]
(
    d.loc[low_quality_strains, "precision"].quantile(_quartiles),
    d.drop(low_quality_strains)["precision"].quantile(_quartiles),
)

In [None]:
# Recall quartiles: (low-quality strains, all other strains).
_quartiles = [0.25, 0.5, 0.75]
(
    d.loc[low_quality_strains, "recall"].quantile(_quartiles),
    d.drop(low_quality_strains)["recall"].quantile(_quartiles),
)