## Preamble

### Project Template

In [None]:
# Load IPython's autoreload extension.
# NOTE(review): no `%autoreload` mode is selected in this cell -- confirm the
# reload mode is configured elsewhere if edited modules (e.g. `lib.*`) are
# expected to be re-imported automatically.
%load_ext autoreload

In [None]:
# Work from the project root so that the relative "meta/..." and "data/..."
# paths used throughout this notebook resolve.
import os as _os

_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
# Echo the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

## Style

In [None]:
# Register matplotlib's FuncScale class with the scale registry; later cells
# request it by name via set_xscale/set_yscale("function", functions=...).
# NOTE(review): recent matplotlib releases presumably register this scale by
# default, which would make this call redundant -- confirm before removing.
mpl.scale.register_scale(mpl.scale.FuncScale)

In [None]:
# Genome metadata table, indexed by genome_id (asserted to be unique).
genome = pd.read_csv("meta/genome.tsv", sep="\t", index_col="genome_id")
assert genome.index.is_unique
genome

In [None]:
# Distinct species with a real species assignment (not the "TODO" placeholder)
# and a non-empty genome_path.
_has_species = genome.species_id != "TODO"
_has_genome_path = genome.genome_path != ""
species_list = list(genome[_has_species & _has_genome_path].species_id.unique())
len(species_list)

In [None]:
# Sample identifiers (mgen_id) listed in the XJIN benchmark metadata table.
_xjin_mgen = pd.read_table("meta/XJIN_BENCHMARK/mgen.tsv", index_col="mgen_id")
xjin_sample_list = _xjin_mgen.index.to_list()

In [None]:
# Species ids (as strings) belonging to the "xjin_ucfmt_hmp2" species group.
_species_group_table = pd.read_table("meta/species_group.tsv")
_in_group = _species_group_table.species_group_id == "xjin_ucfmt_hmp2"
species_group = _species_group_table[_in_group].species_id.astype(str).to_list()
species_group[:5]

In [None]:
# data/group/XJIN_BENCHMARK/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-trim25-p95.STRAIN_MATCH_BENCHMARK_GRID.flag


def _read_strain_match(path):
    """Load one genotype-matching table and attach the derived columns.

    Adds the number of matching genotype positions, a pseudocounted
    dissimilarity (dissimilarity plus one over the number of compared
    positions), and the path the table was read from.
    """
    table = pd.read_table(path, index_col=["genome_id", "strain"])
    n_compared = table.num_geno_positions_compared
    return table.assign(
        genotype_matching_positions=(1 - table.genotype_dissimilarity) * n_compared,
        genotype_dissimilarity_pc=table.genotype_dissimilarity + (1 / n_compared),
        strain_match_path=path,
    )


_STRAIN_MATCH_STEM = (
    "r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95"
)

# One table per genome; genomes without a species or without a stats file on
# disk are skipped.
strain_match = {}
for genome_id, d in genome.iterrows():
    species = d.species_id
    if species != "TODO":
        strain_match_path = f"data/group/XJIN_BENCHMARK/species/sp-{species}/{_STRAIN_MATCH_STEM}.{genome_id}.geno_matching_stats.tsv"
        if os.path.exists(strain_match_path):
            strain_match[genome_id] = _read_strain_match(strain_match_path)

# Stack the per-genome tables into one frame keyed by (genome_id, strain).
_stacked = pd.concat(strain_match.values())
strain_match = _stacked.reset_index().set_index(["genome_id", "strain"])
strain_match

In [None]:
# data/group/XJIN_BENCHMARK/r.proc.gene99_new-v22-agg75.ACCURACY_BENCHMARK_GRID.flag

# NOTE: Strains are matched only from among those with "strain-pure-samples"
# in the xjin set (strain_depth_max > 0), picking the one with the lowest
# pseudocounted genotype dissimilarity (ties broken by higher total depth).
# TODO: Consider also allowing strains to have strain-pure samples
# outside the xjin set, but also be detected at appreciable relative abundances
# inside the xjin samples?
_candidates = strain_match[strain_match.strain_depth_max > 0].reset_index()
_ranked = _candidates.sort_values(
    ["genotype_dissimilarity_pc", "strain_depth_sum"], ascending=(True, False)
)
# First row per genome after sorting is the best match.
_best_match = _ranked.groupby("genome_id").head(1).set_index("genome_id")

# Genomes without any candidate strain get placeholder values.
_no_match_defaults = {
    "strain": -1,
    "genotype_dissimilarity": 1.0,
    "strain_depth_sum": 0,
    "strain_depth_max": 0,
}
_filled = _best_match.reindex(genome.index).fillna(_no_match_defaults)
genome_to_spgc_strain = _filled.assign(strain=_filled.strain.astype(int))

# Distribution of the pseudocounted dissimilarity of the chosen matches.
_dissimilarity_bins = [0] + list(np.logspace(-5, 0))
plt.hist(genome_to_spgc_strain.genotype_dissimilarity_pc, bins=_dissimilarity_bins)
plt.xscale("symlog", linthresh=1e-5, linscale=0.1)
genome_to_spgc_strain.sort_values("genotype_matching_positions", ascending=False)

In [None]:
# Matched-strain total depth against pseudocounted genotype dissimilarity,
# both on symlog axes.
plt.scatter(
    x="strain_depth_sum",
    y="genotype_dissimilarity_pc",
    alpha=0.4,
    data=genome_to_spgc_strain,
)
plt.xscale("symlog", linthresh=1e-2, linscale=0.1)
plt.yscale("symlog", linthresh=1e-4)

In [None]:
def _read_accuracy_table(accuracy_path, species):
    """Load one reconstruction-accuracy table, sorted by F1 (best first).

    The table is indexed by ``strain``. The ``species``, ``accuracy_path`` and
    ``strain`` (copied from the index) columns are attached so each row stays
    identifiable once it is pulled out as an individual Series.
    """
    return (
        pd.read_table(accuracy_path, index_col="strain")
        .assign(species=species, accuracy_path=accuracy_path, strain=lambda x: x.index)
        .sort_values("f1", ascending=False)
    )


benchmark = {}  # (genome_id, tool, unit, match-label) -> accuracy row
missing = []  # accuracy tables that were expected but not found on disk
unmatched = []  # not populated in this cell
depth_meta = {}  # species -> max/sum of species depth over the xjin samples
for genome_id, d in genome.iterrows():
    species = d.species_id
    if species == "TODO":
        continue

    # The depth table depends only on the species, so read and summarise it
    # once per species instead of once per genome.
    if species not in depth_meta:
        depth_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv"
        _depth = pd.read_table(
            depth_path, names=["sample_id", "depth"], index_col="sample_id"
        ).depth.loc[xjin_sample_list]
        depth_meta[species] = pd.Series(
            dict(species_depth_max=_depth.max(), species_depth_sum=_depth.sum())
        )

    for unit in ["uhggtiles", "uhggtop", "eggnog"]:
        # SPGC tools: "match" is the strain chosen by genotype matching
        # (genome_to_spgc_strain); "top"/"second" are the best rows by F1.
        for tool in [
            "spgc-fit",
            # "nnmatched-m50",
            # "nnmatched-m10",
            # "nnmatched-m1",
            # "nnmatched-m0",
            "spgc-depth250",
        ]:
            accuracy_path = f"data/group/XJIN_BENCHMARK/species/sp-{species}/r.proc.gene99_new-v22-agg75.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
            if not os.path.exists(accuracy_path):
                missing.append(accuracy_path)
                continue
            matched_strain = genome_to_spgc_strain.strain[genome_id]
            data = _read_accuracy_table(accuracy_path, species)

            if matched_strain in data.index:
                benchmark[(genome_id, tool, unit, "match")] = data.loc[matched_strain]
            if data.shape[0] >= 1:
                benchmark[(genome_id, tool, unit, "top")] = data.iloc[0]
            if data.shape[0] >= 2:
                benchmark[(genome_id, tool, unit, "second")] = data.iloc[1]

        # Reference tools: there is no genotype-based match, so the best row
        # by F1 is recorded as both "top" and "match".
        for tool in [
            "panphlan",
            # "spanda-s2",
            # "spanda-s3",
            "spanda-s4",
        ]:
            accuracy_path = f"data/group/XJIN_BENCHMARK/species/sp-{species}/r.proc.gene99_new-v22-agg75.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
            if not os.path.exists(accuracy_path):
                missing.append(accuracy_path)
                continue
            data = _read_accuracy_table(accuracy_path, species)
            if data.shape[0] >= 1:
                benchmark[(genome_id, tool, unit, "top")] = data.iloc[0]
                benchmark[(genome_id, tool, unit, "match")] = data.iloc[0]
            if data.shape[0] >= 2:
                benchmark[(genome_id, tool, unit, "second")] = data.iloc[1]
            if data.shape[0] >= 3:
                benchmark[(genome_id, tool, unit, "third")] = data.iloc[2]

# Collapse the dictionaries into frames; the tuple keys become index levels.
benchmark = pd.DataFrame(benchmark.values(), index=benchmark.keys()).rename_axis(
    ["genome_id", "tool", "unit", "match"]
)
depth_meta = pd.DataFrame(depth_meta.values(), index=depth_meta.keys()).rename_axis(
    "species"
)

benchmark.sort_values(["species", "tool"])

In [None]:
# Overlaid histograms of the best ("top") eggnog F1 per genome for the two
# SPGC variants; genomes without a score for a tool count as 0.
bins = np.linspace(0, 1, num=50)
d = (
    benchmark.xs(("eggnog", "top"), level=("unit", "match"))
    .f1.unstack()
    .fillna(0)
)
for tool in ("spgc-fit", "spgc-depth250"):
    plt.hist(d[tool], alpha=0.5, bins=bins, label=tool)
plt.legend()
None

In [None]:
# Per-genome eggnog F1 of panphlan (x) versus spgc-fit (y), coloured by the
# maximum depth of the matched strain.
#
# `benchmark` is indexed by (genome_id, tool, unit, match). The "match" row is
# selected explicitly so that unstack() pivots the tool level into columns;
# selecting only the unit would pivot the match labels and leave no
# "panphlan"/"spgc-fit" columns to plot.
d0 = genome_to_spgc_strain.join(
    benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack("tool")
).fillna(0)

plt.scatter(
    "panphlan",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# Identity line: points above it are genomes where spgc-fit scores higher.
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
# For spgc-fit (eggnog units): F1 of the genotype-matched strain ("match")
# against the best F1 over all strains ("top"), one point per genome.
d0 = genome_to_spgc_strain.join(
    benchmark.xs(("spgc-fit", "eggnog"), level=("tool", "unit")).f1.unstack("match")
).fillna(0)

# No `c=` is supplied, so a colour `norm` would be ignored (matplotlib warns
# about it); it is therefore not passed.
plt.scatter(
    "match",
    "top",
    data=d0,
    alpha=0.4,
)

In [None]:
# Which genomes are failing and why?
# Lists genomes whose spgc-fit "match" F1 is 0 (including genomes with no
# benchmark row at all), together with genome metadata and the number of
# genomes sharing the same species.
#
# The "match" row is selected explicitly so that unstack() yields one column
# per tool; otherwise the match labels are pivoted and x["spgc-fit"] does not
# exist.
(
    genome_to_spgc_strain.join(
        benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack("tool")
    )
    .fillna(0)[lambda x: x["spgc-fit"] == 0]
    .join(genome)
    # .join(depth_meta, on="species_id")
    .join(
        genome.species_id.value_counts().rename("num_strains_in_species"),
        on="species_id",
    )
)

In [None]:
# What's the point of benchmarking strains that have no or very low depth?
# That just even more heavily favors the tools for which the "best hit" is
# taken as the match. Only genomes with species_depth_sum >= 0.5 are kept.
#
# The "match" row is selected explicitly so that unstack() yields one column
# per tool ("panphlan", "spgc-fit", ...), which the scatter below relies on.
d0 = (
    genome_to_spgc_strain.join(
        benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack("tool")
    )
    .fillna(0)
    .join(genome)
    # .join(depth_meta, on="species_id")
    [lambda x: x.species_depth_sum >= 0.5]
)
print(len(d0))

plt.scatter(
    "panphlan",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# Identity line for reference.
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
# Same panphlan-vs-spgc-fit comparison as the depth-filtered scatter
# (species_depth_sum >= 0.5), zoomed in on the high-F1 corner.
#
# The "match" row is selected explicitly so that unstack() yields one column
# per tool, which the scatter below relies on.
d0 = (
    genome_to_spgc_strain.join(
        benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack("tool")
    )
    .fillna(0)
    .join(genome)
    # .join(depth_meta, on="species_id")
    [lambda x: x.species_depth_sum >= 0.5]
)

plt.scatter(
    "panphlan",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# Identity line for reference.
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
plt.ylim(0.55, 1.05)
plt.xlim(0.55, 1.05)

In [None]:
# spgc-fit "match" F1 (eggnog units) against the genotype dissimilarity of the
# matched strain, coloured by the matched strain's maximum depth.
#
# The "match" row is selected explicitly so that unstack() yields one column
# per tool; pivoting the match labels instead would leave no "spgc-fit" column.
d0 = (
    benchmark.f1.xs(("eggnog", "match"), level=("unit", "match"))
    .unstack("tool")
    .fillna(0)
    .join(
        genome_to_spgc_strain[
            [
                "genotype_dissimilarity",
                "strain_depth_sum",
                "strain_depth_max",
            ]
        ]
    )
    # Genomes without a matched strain are treated as maximally dissimilar.
    .fillna({"genotype_dissimilarity": 1.0})
    .fillna(0)
)

plt.scatter(
    "genotype_dissimilarity",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    c="strain_depth_max",
    norm=mpl.colors.SymLogNorm(linthresh=0.1),
)
plt.colorbar()
plt.xscale("symlog", linthresh=1e-5, linscale=0.1)

In [None]:
# Matched-strain total depth against genotype dissimilarity, coloured by the
# spgc-fit "match" F1 (eggnog units) on a fixed 0-1 colour range.
#
# The "match" row is selected explicitly so that unstack() yields one column
# per tool; pivoting the match labels instead would leave no "spgc-fit" column.
d0 = (
    benchmark.xs(("eggnog", "match"), level=("unit", "match"))
    .f1.unstack("tool")
    .fillna(0)
    .join(
        genome_to_spgc_strain[
            [
                "genotype_dissimilarity",
                "strain_depth_sum",
                "strain_depth_max",
            ]
        ]
    )
    # Genomes without a matched strain are treated as maximally dissimilar.
    .fillna({"genotype_dissimilarity": 1.0})
    .fillna(0)
)

plt.scatter(
    "strain_depth_sum",
    "genotype_dissimilarity",
    data=d0,
    alpha=0.4,
    c="spgc-fit",
    vmin=0,
    vmax=1,
)
plt.colorbar()
plt.yscale("symlog", linthresh=1e-4)
plt.xscale("symlog", linthresh=1e-2, linscale=0.1)

In [None]:
# Species represented by more than one genome among the genomes that have a
# strain-match table.
_genome_species_pairs = (
    strain_match.join(genome.species_id)
    .reset_index()[["species_id", "genome_id"]]
    .drop_duplicates()
)
_genomes_per_species = _genome_species_pairs.species_id.value_counts()
multi_genome_species = idxwhere(_genomes_per_species > 1)
multi_genome_species

In [None]:
# spgc-fit eggnog accuracy rows for genomes of assigned species with non-zero
# species depth; missing scores are set to 0.
# NOTE(review): the benchmark rows keep their "match" index level, so a genome
# can contribute several rows (match/top/second) -- confirm that is intended.
_genome_with_match = genome.join(genome_to_spgc_strain)
_keep = (
    (_genome_with_match.species_id != "TODO")
    & (_genome_with_match.species_depth_sum > 0)
    & (_genome_with_match.species_depth_max > 0)
    # & ~_genome_with_match.species_id.isin(multi_genome_species)
)
_spgc_eggnog = benchmark.xs(("spgc-fit", "eggnog"), level=("tool", "unit"))
d = (
    _genome_with_match[_keep]
    .join(_spgc_eggnog, rsuffix="_")
    .fillna({"precision": 0, "recall": 0, "f1": 0, "jaccard": 0})
)
d

In [None]:
# Quartiles of the spgc-fit eggnog scores over genomes of assigned species
# with non-zero species depth; missing scores count as 0.
# NOTE(review): the benchmark rows keep their "match" index level, so the
# quantiles pool the match/top/second rows -- confirm that is intended.
_genome_with_match = genome.join(genome_to_spgc_strain)
_keep = (
    (_genome_with_match.species_id != "TODO")
    & (_genome_with_match.species_depth_sum > 0)
    & (_genome_with_match.species_depth_max > 0)
    # & ~_genome_with_match.species_id.isin(multi_genome_species)
)
_spgc_eggnog = benchmark.xs(("spgc-fit", "eggnog"), level=("tool", "unit"))
d = (
    _genome_with_match[_keep]
    .join(_spgc_eggnog, rsuffix="_")
    .fillna({"precision": 0, "recall": 0, "f1": 0, "jaccard": 0})
)

_score_columns = ["precision", "recall", "f1", "jaccard"]
d[_score_columns].quantile([0.25, 0.5, 0.75])

In [None]:
# Quartiles of the relative F1 difference (spgc-fit - tool) / tool, per tool,
# over genomes of assigned species with non-zero species depth.
# NOTE(review): only the "tool" level is unstacked, so rows for every "match"
# label (match/top/second/third) are pooled -- confirm that is intended.
_genome_with_match = genome.join(genome_to_spgc_strain)
_keep = (
    (_genome_with_match.species_id != "TODO")
    & (_genome_with_match.species_depth_sum > 0)
    & (_genome_with_match.species_depth_max > 0)
    # & ~_genome_with_match.species_id.isin(multi_genome_species)
)
_eggnog_rows = benchmark.xs("eggnog", level="unit")
d = (
    _genome_with_match[_keep]
    .join(_eggnog_rows, rsuffix="_")
    .fillna({"precision": 0, "recall": 0, "f1": 0})
    .f1.unstack("tool")
)

_tool_minus_spgc = d.subtract(d["spgc-fit"], axis=0)
(-_tool_minus_spgc / d).quantile([0.25, 0.5, 0.75])

In [None]:
# Quartiles of the relative recall difference (spgc-fit - tool) / tool, per
# tool, over genomes of assigned species with non-zero species depth.
# NOTE(review): only the "tool" level is unstacked, so rows for every "match"
# label (match/top/second/third) are pooled -- confirm that is intended.
_genome_with_match = genome.join(genome_to_spgc_strain)
_keep = (
    (_genome_with_match.species_id != "TODO")
    & (_genome_with_match.species_depth_sum > 0)
    & (_genome_with_match.species_depth_max > 0)
    # & ~_genome_with_match.species_id.isin(multi_genome_species)
)
_eggnog_rows = benchmark.xs("eggnog", level="unit")
d = (
    _genome_with_match[_keep]
    .join(_eggnog_rows, rsuffix="_")
    .fillna({"precision": 0, "recall": 0, "f1": 0})
    .recall.unstack("tool")
)

_tool_minus_spgc = d.subtract(d["spgc-fit"], axis=0)
(-_tool_minus_spgc / d).quantile([0.25, 0.5, 0.75])

In [None]:
# Figure A: 2D histograms of spgc-fit scores (y) against each reference tool
# (x); one row of panels per score, one column per tool. Only genomes whose
# species has species_depth_sum > 0.1 are included.
unit = "eggnog"

# Select the "match" row for every (genome, tool) so that the remaining index
# is (genome_id, tool) and the tool level can be pivoted into the columns.
# Selecting only the unit would pivot the match labels instead, and the
# d2["spgc-fit"] lookups below would fail.
d1 = (
    benchmark.xs((unit, "match"), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum > 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack("tool")
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
_tool_comparison_order = ["panphlan", "spanda-s4"]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order) + 0.5, 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
nbins = 20
left_bound = 0.0
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # hist2d returns (counts, xedges, yedges, image); the image is kept
        # for the shared colorbar.
        *_, cbar_artist = ax.hist2d(
            x,
            y,
            data=d2,
            bins=bins,
            cmin=1,
            norm=mpl.colors.PowerNorm(1 / 1, vmin=0, vmax=21),
            cmap="magma_r",
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference and a Wilcoxon signed-rank test.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])

# Set on the last panel only; the axes were created with sharex/sharey=True.
ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

# Reserve the right margin for a single colorbar shared by all panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(
    cbar_artist, cax=cbar_ax, ticks=[1, 10, 20, 30, 40, 50], label="count strains"
)

In [None]:
# Figure A (scatter variant): spgc-fit scores (y) against each reference tool
# (x), coloured by the matched strain's maximum depth, on axes that stretch
# the region near 1. Only genomes whose species has species_depth_sum > 0.1
# are included.
unit = "eggnog"
# Lower end of the identity line; defined here so the cell does not depend on
# a value left behind by another cell.
left_bound = 0.0

# Select the "match" row for every (genome, tool) so that the remaining index
# is (genome_id, tool) and the tool level can be pivoted into the columns.
# Selecting only the unit would pivot the match labels instead, and the
# d2["spgc-fit"] lookups below would fail.
d1 = (
    benchmark.xs((unit, "match"), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum > 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack("tool")
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
_tool_comparison_order = ["panphlan", "spanda-s4"]

_score_order = ["precision", "recall", "f1", "jaccard"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    # Attach genome / matched-strain / depth metadata for the point colours.
    d2 = d1.xs(score, level="score", axis="columns").join(
        genome.join(genome_to_spgc_strain).join(
            depth_meta, on="species_id", rsuffix="_"
        )
    )
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        ax.scatter(
            x,
            y,
            data=d2,
            s=5,
            c="strain_depth_max",
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference and a Wilcoxon signed-rank test.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])


# Axis transform -log(1 + pad - v): spreads out values close to 1. `pad` keeps
# the logarithm finite at v == 1. Set on the last panel only; the axes were
# created with sharex/sharey=True.
pad = 1e-2
ax.set_yscale(
    "function",
    functions=(lambda x: -np.log(1 + pad - x), lambda y: (1 + pad - np.exp(-y))),
)
ax.set_yticks(np.unique([0.0, 0.5, 0.75, 0.9, 0.95, 0.99]))
ax.set_ylim(-0.1, 1.00)

ax.set_xscale(
    "function",
    functions=(lambda x: -np.log(1 + pad - x), lambda y: (1 + pad - np.exp(-y))),
)
ax.set_xticks(np.unique([0.0, 0.5, 0.75, 0.9, 0.95, 0.99]))
ax.set_xlim(-0.1, 1.00)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)
    lib.plot.rotate_xticklabels(ax=ax_col[-1], rotation=45)

In [None]:
# Figure A, extended: 2D histograms of spgc-fit scores (y) against panphlan,
# spanda-s4 and spgc-depth250 (x); one row of panels per score. Only genomes
# whose species has species_depth_sum > 0.1 are included.
unit = "eggnog"

# Select the "match" row for every (genome, tool) so that the remaining index
# is (genome_id, tool) and the tool level can be pivoted into the columns.
# Selecting only the unit would pivot the match labels instead, and the
# d2["spgc-fit"] lookups below would fail.
d1 = (
    benchmark.xs((unit, "match"), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum > 0.1
    ][["f1", "precision", "recall"]]
    .unstack("tool")
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
_tool_comparison_order = [
    "panphlan",
    "spanda-s4",
    # "nnmatched-m50",
    # "nnmatched-m10",
    # "nnmatched-m1",
    # "nnmatched-m0",
    "spgc-depth250",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order) + 0.5, 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
nbins = 20
left_bound = 0.0
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # hist2d returns (counts, xedges, yedges, image); the image is kept
        # for the shared colorbar.
        *_, cbar_artist = ax.hist2d(
            x,
            y,
            data=d2,
            bins=bins,
            cmin=1,
            norm=mpl.colors.PowerNorm(1 / 1, vmin=0, vmax=21),
            cmap="magma_r",
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference and a Wilcoxon signed-rank test.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])

# Set on the last panel only; the axes were created with sharex/sharey=True.
ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

# Reserve the right margin for a single colorbar shared by all panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(
    cbar_artist, cax=cbar_ax, ticks=[1, 10, 20, 30, 40, 50], label="count strains"
)

In [None]:
# Table of "match" F1 for spgc-fit and spgc-depth250, sorted by spgc-fit F1
# and joined with the genome metadata. Only genomes whose species has
# species_depth_sum > 0.1 are included.
unit = "eggnog"

# Select the "match" row for every (genome, tool) so that the tool level can
# be pivoted into the columns. Selecting only the unit would pivot the match
# labels instead, and the "spgc-fit"/"spgc-depth250" columns would not exist.
d1 = (
    benchmark.xs((unit, "match"), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum > 0.1
    ][["f1", "precision", "recall"]]
    .unstack("tool")
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

d1.f1[["spgc-fit", "spgc-depth250"]].sort_values("spgc-fit").join(genome)

In [None]:
# Figure: accuracy scores of each tool as a function of matched-strain depth.
# One panel per score; for every tool the per-genome "match" scores are drawn
# as open markers together with a 10-point rolling average along the depth
# axis.
unit = "eggnog"

# Genomes whose species_depth_sum (a column of genome_to_spgc_strain) is > 0.
_genome_list = idxwhere(genome_to_spgc_strain.species_depth_sum > 0)

_tool_list = [
    "spgc-fit",
    "panphlan",
    "spanda-s4",
    # "nnmatched-m50"
]
_score_list = ["precision", "recall", "f1"]
xvar = "strain_depth_sum"

# _tool_palette = lib.plot.construct_ordered_palette(_tool_list)
# NOTE(review): four colours are supplied for three tools; presumably the
# extra colour is left over from the commented-out fourth tool -- confirm.
_tool_palette = lib.plot.construct_ordered_palette_from_list(
    _tool_list, colors=["tab:blue", "tab:green", "tab:orange", "tab:purple"]
)


# "match" scores for every (genome, tool) pair; the reindex makes sure each
# pair has a row, and pairs without a benchmark entry get a score of 0.
d0 = (
    benchmark.xs((unit, "match"), level=("unit", "match"))[_score_list]
    .reindex(product(_genome_list, _tool_list))
    .fillna(0)
    .join(genome)
    # [lambda x: ~x.species_id.isin(multi_genome_species)]
    .join(depth_meta, on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)

fig, axs = plt.subplots(
    len(_score_list),
    figsize=(5, 3 * len(_score_list)),
    sharex=True,
    sharey=True,
)

for _tool in _tool_list:
    d1 = d0.xs(_tool, level="tool")
    for _score, ax in zip(_score_list, axs.flatten()):
        # NOTE(review): the spline fit, `spline_y`, `lowess_y` and `smoothed`
        # are computed here but not drawn by this cell -- only the scatter
        # and the rolling averages below are plotted.
        fit = smf.ols(f'{_score} ~ cr({xvar}, 5, constraints="center")', data=d1).fit()
        d2 = d1.assign(
            lowess_y=lambda d: sm.nonparametric.lowess(
                d[_score], d[xvar], it=10, frac=1 / 3, return_sorted=False
            ),
            spline_y=lambda d: fit.predict(pd.DataFrame({xvar: d[xvar]})),
            # Rolling means are taken over rows sorted by depth and assigned
            # back to the frame (aligned on the index).
            xvar_rolling_average=lambda d: d.sort_values(xvar)
            .rolling(window=10)[xvar]
            .mean(),
            score_rolling_average=lambda d: d.sort_values(xvar)
            .rolling(window=10)[_score]
            .mean(),
        )
        smoothed = pd.DataFrame({xvar: np.logspace(-0.5, 4, num=100)}).assign(
            spline_y=lambda d: fit.predict(pd.DataFrame({xvar: d[xvar]})),
            lowess_y=lambda d: sm.nonparametric.lowess(
                d2[_score],
                d2[xvar],
                xvals=d[xvar],
                it=10,
                frac=1 / 3,
                return_sorted=False,
            ),
        )
        # Per-genome scores as open markers.
        ax.scatter(
            xvar,
            _score,
            data=d2,
            label=_tool,
            s=8,
            alpha=0.5,
            facecolor="none",
            color=_tool_palette[_tool],
        )
        # Rolling-average trend line, excluded from the legend.
        ax.plot(
            "xvar_rolling_average",
            "score_rolling_average",
            data=d2.sort_values(xvar),
            label="__nolegend__",
            lw=1,
            linestyle="-",
            color=_tool_palette[_tool],
        )

for _score, ax in zip(_score_list, axs.flatten()):
    ax.set_ylabel(_score)

# The settings below are applied to the last panel (`ax` left over from the
# loop above); the axes were created with sharex/sharey=True.
ax.set_xscale("symlog", linthresh=1e-1, linscale=0.1)

# Axis transform -log(1 + pad - v): spreads out values close to 1; `pad` keeps
# the logarithm finite at v == 1.
pad = 1e-1
ax.set_yscale(
    "function",
    functions=(lambda x: -np.log(1 + pad - x), lambda y: (1 + pad - np.exp(-y))),
)
ax.set_yticks(np.unique([0.0, 0.5, 0.75, 0.9, 0.95, 0.99]))
ax.set_ylim(-0.1, 1.00)

ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# spgc-fit "match" scores for every genome, next to the matched-strain
# metadata; genomes without a benchmark row get 0 for all scores.
# Relies on `unit` and `_score_list` set by an earlier cell.
_spgc_match_scores = benchmark.xs(
    (unit, "match", "spgc-fit"), level=("unit", "match", "tool")
)[_score_list]
d0 = (
    _spgc_match_scores.fillna(0)
    .reindex(genome_to_spgc_strain.index, fill_value=0)
    .join(genome_to_spgc_strain, rsuffix="_")
)

d0

In [None]:
# Compare eggnog F1 between the spanda-s2 and spanda-s4 runs and list the
# entries where s4 exceeds s2 by more than 0.03.
# NOTE(review): "spanda-s2" is commented out of the tool list in the cell that
# builds `benchmark`, and unstack() here pivots the last index level ("match")
# rather than "tool" -- this cell presumably predates those changes; confirm
# before relying on it.
d = benchmark.xs("eggnog", level="unit").f1.unstack().fillna(0)
plt.scatter("spanda-s2", "spanda-s4", data=d)
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
idxwhere(d["spanda-s4"] - d["spanda-s2"] > 0.03)

In [None]:
# Figure A restricted to species represented by more than one genome
# (multi_genome_species): 2D histograms of spgc-fit scores (y) against each
# comparison tool (x); one row of panels per score.
unit = "eggnog"

# Select the "match" row for every (genome, tool) so that the remaining index
# is (genome_id, tool) and the tool level can be pivoted into the columns.
# Selecting only the unit would pivot the match labels instead, and the
# d2["spgc-fit"] lookups below would fail.
d1 = (
    benchmark.xs((unit, "match"), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_id.isin(multi_genome_species)
    ][["f1", "precision", "recall"]]
    .unstack("tool")
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
_requested_tools = [
    "panphlan",
    "spanda-s4",
    "nnmatched-m50",
    "nnmatched-m10",
    "nnmatched-m1",
    "nnmatched-m0",
    "spgc-depth250",
]
# Only tools that actually have benchmark rows can be plotted; the others are
# reported and skipped instead of failing on a missing column.
_loaded_tools = set(d1.columns.get_level_values("tool"))
_tool_comparison_order = [t for t in _requested_tools if t in _loaded_tools]
_skipped_tools = [t for t in _requested_tools if t not in _loaded_tools]
if _skipped_tools:
    print("skipping tools with no benchmark rows:", _skipped_tools)
if not _tool_comparison_order:
    raise ValueError("none of the requested comparison tools have benchmark rows")

_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order) + 0.5, 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
    squeeze=False,  # keep a 2-D array of axes even with a single column
)
y = "spgc-fit"
nbins = 20
left_bound = 0.0
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # hist2d returns (counts, xedges, yedges, image); the image is kept
        # for the shared colorbar.
        *_, cbar_artist = ax.hist2d(
            x,
            y,
            data=d2,
            bins=bins,
            cmin=1,
            norm=mpl.colors.PowerNorm(1 / 1, vmin=0, vmax=21),
            cmap="magma_r",
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference and a Wilcoxon signed-rank test.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])

# Set on the last panel only; the axes were created with sharex/sharey=True.
ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

# Reserve the right margin for a single colorbar shared by all panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(
    cbar_artist, cax=cbar_ax, ticks=[1, 10, 20, 30, 40, 50], label="count strains"
)

In [None]:
# Figure A (scatter variant) restricted to species represented by more than
# one genome (multi_genome_species): spgc-fit scores (y) against each
# comparison tool (x), on axes that stretch the region near 1.
unit = "eggnog"
# Lower end of the identity line; defined here so the cell does not depend on
# a value left behind by another cell.
left_bound = 0.0

# Select the "match" row for every (genome, tool) so that the remaining index
# is (genome_id, tool) and the tool level can be pivoted into the columns.
# Selecting only the unit would pivot the match labels instead, and the
# d2["spgc-fit"] lookups below would fail.
d1 = (
    benchmark.xs((unit, "match"), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_id.isin(multi_genome_species)
    ][["f1", "precision", "recall"]]
    .unstack("tool")
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
_requested_tools = [
    "panphlan",
    "spanda-s4",
    "nnmatched-m50",
    "nnmatched-m10",
    "nnmatched-m1",
    "nnmatched-m0",
    "spgc-depth250",
]
# Only tools that actually have benchmark rows can be plotted; the others are
# reported and skipped instead of failing on a missing column.
_loaded_tools = set(d1.columns.get_level_values("tool"))
_tool_comparison_order = [t for t in _requested_tools if t in _loaded_tools]
_skipped_tools = [t for t in _requested_tools if t not in _loaded_tools]
if _skipped_tools:
    print("skipping tools with no benchmark rows:", _skipped_tools)
if not _tool_comparison_order:
    raise ValueError("none of the requested comparison tools have benchmark rows")

_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
    squeeze=False,  # keep a 2-D array of axes even with a single column
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        ax.scatter(
            x,
            y,
            data=d2,
            s=5,
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference and a Wilcoxon signed-rank test.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])


# Axis transform -log(1 + pad - v): spreads out values close to 1. `pad` keeps
# the logarithm finite at v == 1. Set on the last panel only; the axes were
# created with sharex/sharey=True.
pad = 1e-2
ax.set_yscale(
    "function",
    functions=(lambda x: -np.log(1 + pad - x), lambda y: (1 + pad - np.exp(-y))),
)
ax.set_yticks(np.unique([0.0, 0.5, 0.9, 0.99, 1.00]))
ax.set_ylim(-0.1, 1.00)

ax.set_xscale(
    "function",
    functions=(lambda x: -np.log(1 + pad - x), lambda y: (1 + pad - np.exp(-y))),
)
ax.set_xticks(np.unique([0.0, 0.5, 0.9, 0.99, 1.00]))
ax.set_xlim(-0.1, 1.00)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)