## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so the relative `meta/` and `data/` paths used
# throughout resolve; echo the resolved working directory as cell output.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

## Style

In [None]:
# Register matplotlib's function-defined scale so `set_xscale("function", ...)`
# can be used for custom axis transforms.
# NOTE(review): matplotlib appears to register FuncScale under "function" by
# default, which would make this call redundant — confirm for the installed version.
mpl.scale.register_scale(mpl.scale.FuncScale)

In [None]:
# IDs of every genome assigned to the "xjin" genome group.
_genome_group_meta = pd.read_table("meta/genome_group.tsv", index_col="genome_id")
_in_xjin_group = _genome_group_meta.genome_group_id == "xjin"
genome_list = idxwhere(_in_xjin_group)
len(genome_list)

In [None]:
# Per-genome metadata, restricted to (and ordered like) the xjin genome list.
_genome_meta_all = pd.read_table("meta/genome.tsv", index_col="genome_id")
genome = _genome_meta_all.loc[genome_list]
assert genome.index.is_unique

In [None]:
# Distinct species among xjin genomes that have both an assigned species and
# a genome file path.
species_list = list(
    genome.loc[genome_list][
        lambda x: (x.species_id != "TODO")
        # `read_table` parses empty cells as NaN rather than "", so the
        # empty-string test alone never excludes a missing path.
        & x.genome_path.notna()
        & (x.genome_path != "")
    ].species_id.unique()
)
len(species_list)

In [None]:
# Species IDs (as strings) in the "xjin" species group; preview the first few.
_species_group_meta = pd.read_table("meta/species_group.tsv")
_is_xjin = _species_group_meta.species_group_id == "xjin"
species_group = _species_group_meta.loc[_is_xjin, "species_id"].astype(str).to_list()
species_group[:5]

In [None]:
# Inspect the xjin genomes assigned to species 100003.
genome[genome.species_id == '100003']

In [None]:
# Inspect one species' sfacts strain-metadata table. The path is built here
# explicitly (species 100003, inspected just above) rather than relying on
# `strain_meta_path` leaking out of the loop in a later cell, which raised
# NameError on a fresh kernel.
_inspect_species = "100003"
_inspect_strain_meta_path = f"data/group/xjin/species/sp-{_inspect_species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta-s90-d100-a1-pos100.tsv"
pd.read_table(_inspect_strain_meta_path)

In [None]:
# Related pipeline flag target:
# data/group/xjin/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.STRAIN_MATCH_BENCHMARK_GRID.flag
#
# For each xjin genome with an assigned species, load (when present) the
# genotype comparison of that genome against each inferred strain, add derived
# match columns, and attach the per-strain metadata table.
_strain_match_frames = []
for genome_id, d in genome.loc[genome_list].iterrows():
    species = d.species_id
    if species == "TODO":
        continue
    strain_match_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.{genome_id}.geno_matching_stats.tsv"
    strain_meta_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta-s90-d100-a1-pos100.tsv"
    if not os.path.exists(strain_match_path):
        continue
    strain_meta = pd.read_table(strain_meta_path, index_col="genome_id").rename_axis("strain")
    _matching_stats = pd.read_table(strain_match_path, index_col=["genome_id", "strain"])
    _strain_match_frames.append(
        _matching_stats.assign(
            # Number of compared positions at which the genotypes agree.
            genotype_matching_positions=lambda x: (1 - x.genotype_dissimilarity) * x.num_geno_positions_compared,
            # Dissimilarity plus a one-position pseudo-count (never zero).
            genotype_dissimilarity_pc=lambda x: x.genotype_dissimilarity + (1 / x.num_geno_positions_compared),
            strain_match_path=strain_match_path,
        ).join(strain_meta, on="strain")
    )

strain_match = (
    pd.concat(_strain_match_frames)
    .reset_index()
    .set_index(["genome_id", "strain"])
)
strain_match

In [None]:
# NOTE: We match strains only from among those with "strain-pure-samples" in the xjin set
# and then picking the lowest genotype_dissimilarity.
# TODO: Consider also allowing strains to have strain pure samples
# outside the xjin set, but also be detected at appreciable relative abundances
# inside the xjin samples?
#
# Per genome, keep the single best candidate: lowest pseudo-counted genotype
# dissimilarity, ties broken by higher total strain depth. Genomes without any
# candidate are kept (via reindex) with strain = -1 and worst-case fill values.
genome_to_spgc_strain = (
    strain_match[lambda x: x.strain_depth_max > 0]
    .reset_index()
    .sort_values(
        ["genotype_dissimilarity_pc", "strain_depth_sum"], ascending=(True, False)
    )
    .groupby("genome_id")
    .head(1)
    .set_index("genome_id")
    .reindex(genome.index)
    .fillna(
        {
            "strain": -1,
            "genotype_dissimilarity": 1.0,
            # Fill the pseudo-counted dissimilarity like the raw one so that
            # unmatched genomes are not silently dropped (NaN) from the
            # histogram below and later plots that use it.
            "genotype_dissimilarity_pc": 1.0,
            "strain_depth_sum": 0,
            "strain_depth_max": 0,
        }
    )
    .assign(strain=lambda x: x.strain.astype(int))
)

plt.hist(
    genome_to_spgc_strain.genotype_dissimilarity_pc, bins=[0] + list(np.logspace(-5, 0))
)
plt.xscale("symlog", linthresh=1e-5, linscale=0.1)
genome_to_spgc_strain.sort_values("genotype_matching_positions", ascending=False)

In [None]:
# Genotype dissimilarity (pseudo-counted) of each genome's matched strain
# against that strain's total depth, both on symlog axes.
fig_dissim, ax_dissim = plt.subplots()
ax_dissim.scatter(
    "strain_depth_sum", "genotype_dissimilarity_pc", data=genome_to_spgc_strain, alpha=0.4
)
ax_dissim.set_yscale("symlog", linthresh=1e-4)
ax_dissim.set_xscale("symlog", linthresh=1e-2, linscale=0.1)

In [None]:
# Related pipeline flag targets:
# data/group/xjin/r.proc.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv.SELECT_SPECIES.flag
# data/group/xjin/r.proc.gene99_v15-v22-agg75.ACCURACY_BENCHMARK_GRID.flag
# data/group/xjin/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta-hmp2-s90-d100-a1-pos100.tsv.SELECT_SPECIES.flag
#
# For every xjin genome with an assigned species, this builds:
#   benchmark  -- accuracy rows keyed (genome_id, tool, unit, match)
#   depth_meta -- max and sum over samples of the species depth, per species
#   spgc_qc    -- row of the hmp2 strain-metadata table for the matched strain
#   missing    -- accuracy-table paths that were expected but not found


def _read_accuracy_table(accuracy_path, species):
    """Load one per-strain reconstruction-accuracy table, best F1 first.

    The `strain` index is also copied into a `strain` column, and every row is
    tagged with `species` and its source `accuracy_path`.
    """
    return (
        pd.read_table(accuracy_path, index_col="strain")
        .assign(species=species, accuracy_path=accuracy_path, strain=lambda x: x.index)
        .sort_values("f1", ascending=False)
    )


benchmark = {}
missing = []
unmatched = []  # NOTE(review): never populated in the visible notebook.
depth_meta = {}
spgc_qc = {}
for genome_id, d in genome.iterrows():
    species = d.species_id
    if species == "TODO":
        continue

    # Strain matched to this genome by genotype comparison (-1 if none).
    # Looked up once per genome: it used to be assigned only inside the SPGC
    # loop when an accuracy file existed, so the QC lookup at the end of the
    # iteration could reuse a stale strain from a previous genome (or raise
    # NameError for the first genome).
    matched_strain = genome_to_spgc_strain.strain[genome_id]

    depth_path = f"data/group/xjin/species/sp-{species}/r.proc.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv"
    _depth = pd.read_table(
        depth_path, names=["sample_id", "depth"], index_col="sample_id"
    ).depth
    depth_meta[species] = pd.Series(
        dict(species_depth_max=_depth.max(), species_depth_sum=_depth.sum())
    )

    for unit in ["uhggtop", "eggnog", "cog"]:
        # SPGC: record the genotype-matched strain ("match") plus the best and
        # second-best strains by F1.
        for tool in [
            "spgc-fit",
            # "spgc2-fit",
            # "nnmatched-m50",
            # "nnmatched-m10",
            # "nnmatched-m1",
            # "nnmatched-m0",
            "spgc-depth200",
        ]:
            accuracy_path = f"data/group/xjin/species/sp-{species}/r.proc.gene99_v15-v22-agg75.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
            if not os.path.exists(accuracy_path):
                missing.append(accuracy_path)
                continue
            data = _read_accuracy_table(accuracy_path, species)
            if matched_strain in data.index:
                benchmark[(genome_id, tool, unit, "match")] = data.loc[matched_strain]
            if data.shape[0] >= 1:
                benchmark[(genome_id, tool, unit, "top")] = data.iloc[0]
            if data.shape[0] >= 2:
                benchmark[(genome_id, tool, unit, "second")] = data.iloc[1]
        # Comparison tools have no genotype-matched strain, so their best-F1
        # strain also serves as the "match" row (a best-hit choice).
        for tool in [
            "panphlan",
            # "spanda-s2",
            # "spanda-s3",
            "spanda-s4",
        ]:
            accuracy_path = f"data/group/xjin/species/sp-{species}/r.proc.gene99_v15-v22-agg75.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
            if not os.path.exists(accuracy_path):
                missing.append(accuracy_path)
                continue
            data = _read_accuracy_table(accuracy_path, species)
            if data.shape[0] >= 1:
                benchmark[(genome_id, tool, unit, "top")] = data.iloc[0]
                benchmark[(genome_id, tool, unit, "match")] = data.iloc[0]
            if data.shape[0] >= 2:
                benchmark[(genome_id, tool, unit, "second")] = data.iloc[1]
            if data.shape[0] >= 3:
                benchmark[(genome_id, tool, unit, "third")] = data.iloc[2]

    spgc_qc_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta-hmp2-s90-d100-a1-pos100.tsv"
    if not os.path.exists(spgc_qc_path):
        continue
    _qc = pd.read_table(spgc_qc_path, index_col="genome_id")
    if matched_strain in _qc.index:
        spgc_qc[genome_id] = _qc.loc[matched_strain]

benchmark = pd.DataFrame(benchmark.values(), index=benchmark.keys()).rename_axis(
    ["genome_id", "tool", "unit", "match"]
)
depth_meta = pd.DataFrame(depth_meta.values(), index=depth_meta.keys()).rename_axis(
    "species"
)
spgc_qc = pd.DataFrame(spgc_qc.values(), index=spgc_qc.keys())

benchmark.sort_values(["species", "tool"])

In [None]:
# Distribution of each tool's best ("top") eggnog F1 per genome; a missing
# score counts as 0. Other tools (nnmatched-*, panphlan, spanda-*) can be
# added to `_hist_tools`.
bins = np.linspace(0, 1, num=50)
d = (
    benchmark.xs(("eggnog", "top"), level=("unit", "match"))
    .f1
    .unstack()
    .fillna(0)
)
_hist_tools = ["spgc-fit", "spgc-depth200"]
for tool in _hist_tools:
    plt.hist(d[tool], bins=bins, alpha=0.5, label=tool)
plt.legend()
None

In [None]:
# SPANDA (s4) vs. SPGC eggnog F1 for each genome's "match" row; missing
# scores count as 0. Colour = max strain depth; dashed line is y = x.
_eggnog_match_f1 = benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
d0 = genome_to_spgc_strain.join(_eggnog_match_f1).fillna(0)

_points = plt.scatter(
    "spanda-s4",
    "spgc-fit",
    data=d0,
    c="strain_depth_max",
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    alpha=0.4,
)
plt.colorbar(_points)
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
# SPANDA (s4) uhggtop scores for each genome's "match" row, best F1 first.
_spanda_uhggtop = benchmark.xs(
    ("uhggtop", "match", "spanda-s4"), level=("unit", "match", "tool")
)
_spanda_uhggtop.sort_values("f1", ascending=False)

In [None]:
# SPGC eggnog F1 of the genotype-matched strain ("match", x) vs. the best-F1
# strain ("top", y) per genome. "top" is the highest F1 in the same table, so
# points lie on or above y = x; points above it are genomes whose matched
# strain is not the best reconstruction.
d0 = genome_to_spgc_strain.join(
    benchmark.xs(("spgc-fit", "eggnog"), level=("tool", "unit")).f1.unstack()
).fillna(0)

# No `c=` is passed, so a colour `norm` would be ignored (matplotlib warns);
# it was dropped. Add `c=` and `norm=` together to colour the points.
plt.scatter(
    "match",
    "top",
    data=d0,
    alpha=0.4,
)

In [None]:
# Precision vs. recall of SPGC's matched-strain eggnog reconstructions,
# with a y = x reference line.
_spgc_eggnog_match = benchmark.xs(
    ("spgc-fit", "eggnog", "match"), level=("tool", "unit", "match")
)
plt.scatter("precision", "recall", data=_spgc_eggnog_match)
plt.plot([0, 1], [0, 1])

In [None]:
# Which genomes are failing and why? List genomes whose SPGC eggnog F1 for the
# "match" row is 0 (missing counts as 0), with genome metadata and the number
# of xjin genomes that share each species.
_eggnog_match_f1 = benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
_scored = genome_to_spgc_strain.join(_eggnog_match_f1).fillna(0)
_genomes_per_species = genome.species_id.value_counts().rename("num_strains_in_species")
(
    _scored[_scored["spgc-fit"] == 0]
    .join(genome)
    .join(_genomes_per_species, on="species_id")
)

In [None]:
# What's the point of benchmarking strains that have no or very low depth?
# that just even more heavily favors the tools that I'm taking "best hit" for.
# -> Keep only genomes with species_depth_sum >= 0.5 (presumably a column
#    carried in via genome_to_spgc_strain's strain metadata — confirm).
_scored = (
    genome_to_spgc_strain
    .join(benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack())
    .fillna(0)
    .join(genome)
)
d0 = _scored[_scored.species_depth_sum >= 0.5]
print(len(d0))

_points = plt.scatter(
    "panphlan",
    "spgc-fit",
    data=d0,
    c="strain_depth_max",
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    alpha=0.4,
)
plt.colorbar(_points)
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")

In [None]:
# PanPhlAn ("panphlan") vs. SPGC eggnog F1 for genomes with
# species_depth_sum >= 0.5, zoomed into the 0.55-1.05 range on both axes.
# NOTE(review): the data preparation duplicates the previous cell verbatim;
# consider reusing that cell's `d0` or a shared helper.
d0 = (
    genome_to_spgc_strain.join(benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack())
    .fillna(0)
    .join(genome)
    # .join(depth_meta, on="species_id")
    [lambda x: x.species_depth_sum >= 0.5]
)

plt.scatter(
    "panphlan",
    "spgc-fit",
    data=d0,
    alpha=0.4,
    norm=mpl.colors.SymLogNorm(linthresh=0.2),
    c="strain_depth_max",
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
plt.ylim(0.55, 1.05)
plt.xlim(0.55, 1.05)

In [None]:
# SPGC eggnog F1 ("match" row) vs. the genotype dissimilarity of each genome's
# matched strain; missing dissimilarity -> 1.0, other missing values -> 0.
# Colour = max strain depth (symlog).
_match_f1 = benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack().fillna(0)
_strain_cols = genome_to_spgc_strain[
    ["genotype_dissimilarity", "strain_depth_sum", "strain_depth_max"]
]
d0 = _match_f1.join(_strain_cols).fillna({"genotype_dissimilarity": 1.0}).fillna(0)

_points = plt.scatter(
    "genotype_dissimilarity",
    "spgc-fit",
    data=d0,
    c="strain_depth_max",
    norm=mpl.colors.SymLogNorm(linthresh=0.1),
    alpha=0.4,
)
plt.colorbar(_points)
plt.xscale("symlog", linthresh=1e-5, linscale=0.1)

In [None]:
# Genotype dissimilarity vs. total depth of each genome's matched strain,
# coloured by SPGC eggnog F1 on a fixed 0-1 scale. Missing dissimilarity -> 1.0,
# other missing values -> 0.
_match_f1 = benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack().fillna(0)
_strain_cols = genome_to_spgc_strain[
    ["genotype_dissimilarity", "strain_depth_sum", "strain_depth_max"]
]
d0 = _match_f1.join(_strain_cols).fillna({"genotype_dissimilarity": 1.0}).fillna(0)

_points = plt.scatter(
    "strain_depth_sum",
    "genotype_dissimilarity",
    data=d0,
    c="spgc-fit",
    vmin=0,
    vmax=1,
    alpha=0.4,
)
plt.colorbar(_points)
plt.yscale("symlog", linthresh=1e-4)
plt.xscale("symlog", linthresh=1e-2, linscale=0.1)

In [None]:
# NOTE(review): this cell is an exact duplicate of the previous cell (same
# `d0`, same figure); it looks like a candidate for deletion.
# Genotype dissimilarity vs. total depth of each genome's matched strain,
# coloured by SPGC eggnog F1 on a fixed 0-1 scale.
d0 = (
    benchmark.xs(("eggnog", "match"), level=("unit", "match")).f1.unstack()
    .fillna(0)
    .join(
        genome_to_spgc_strain[
            [
                "genotype_dissimilarity",
                "strain_depth_sum",
                "strain_depth_max",
            ]
        ]
    )
    .fillna({"genotype_dissimilarity": 1.0})
    .fillna(0)
)

plt.scatter(
    "strain_depth_sum",
    "genotype_dissimilarity",
    data=d0,
    alpha=0.4,
    c="spgc-fit",
    vmin=0,
    vmax=1,
)
plt.colorbar()
# plt.scatter(d1['spanda-s2'], d1['spgc-fit'])
# plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")
plt.yscale("symlog", linthresh=1e-4)
plt.xscale("symlog", linthresh=1e-2, linscale=0.1)

In [None]:
# Species with genotype-matching results for more than one xjin genome.
_species_genome_pairs = (
    strain_match.join(genome.species_id)
    .reset_index()[["species_id", "genome_id"]]
    .drop_duplicates()
)
_n_matched_genomes = _species_genome_pairs.species_id.value_counts()
multi_genome_species = idxwhere(_n_matched_genomes > 1)
multi_genome_species

In [None]:
# Per-genome SPGC eggnog scores for the "match" row, restricted to genomes with
# an assigned species and non-zero species depth; missing scores -> 0.
_genome_strain = genome.join(genome_to_spgc_strain)
_keep = (
    (_genome_strain.species_id != "TODO")
    & (_genome_strain.species_depth_sum > 0)
    & (_genome_strain.species_depth_max > 0)
    # & ~_genome_strain.species_id.isin(multi_genome_species)
)
_spgc_eggnog_match = benchmark.xs(
    ("spgc-fit", "eggnog", "match"), level=("tool", "unit", "match")
)
d = (
    _genome_strain[_keep]
    .join(_spgc_eggnog_match, rsuffix="_")
    .fillna({"precision": 0, "recall": 0, "f1": 0, "jaccard": 0})
)
d

In [None]:
# SPGC eggnog precision vs. recall per genome ("match" row), coloured by the
# matched strain's species_gene_frac; then score quartiles for all genomes and
# for those with species_gene_frac > 0.9.
d = (
    # Genomes with an assigned species and non-zero species depth.
    genome.join(genome_to_spgc_strain)[
        lambda x: (x.species_id != "TODO")
        & (x.species_depth_sum > 0)
        & (x.species_depth_max > 0)
        # & ~x.species_id.isin(multi_genome_species)
    ]
    .join(benchmark.xs(("eggnog", "match"), level=("unit", "match")), rsuffix="_")
    # .fillna({"precision": 0, "recall": 0, "f1": 0, "jaccard": 0})
    [['precision', 'recall', 'f1', 'jaccard']]
    # Pivot tools into columns, then keep only SPGC's scores; a genome with
    # no SPGC row ends up with 0 after the fillna below.
    .unstack('tool')
    .xs('spgc-fit', level='tool', axis='columns')
    .fillna(0)
    .join(genome_to_spgc_strain.species_gene_frac)
)

plt.scatter('precision', "recall", data=d, c='species_gene_frac')
plt.colorbar()
plt.plot([0, 1], [0, 1])
print(d[["precision", "recall", "f1", "jaccard"]].quantile([0.25, 0.5, 0.75]))
print(d[d.species_gene_frac > 0.9][["precision", "recall", "f1", "jaccard"]].quantile([0.25, 0.5, 0.75]))

In [None]:
# Per-genome advantage of SPGC's eggnog F1 over each tool (spgc-fit minus
# tool, "match" rows, missing -> 0), as quartiles. Computed for all genomes
# passing the depth filter, then for those with species_gene_frac > 0.9.
d = (
    genome.join(genome_to_spgc_strain)[
        lambda x: (x.species_id != "TODO")
        & (x.species_depth_sum > 0)
        & (x.species_depth_max > 0)
        # & ~x.species_id.isin(multi_genome_species)
    ]
    .join(benchmark.xs(("eggnog", "match"), level=("unit", "match")), rsuffix="_")
    .f1.unstack("tool")
    .fillna(0)
    .join(genome_to_spgc_strain.species_gene_frac)
)

d_filt = d[d.species_gene_frac > 0.9]

# Only the tool columns are scores. Exclude the species_gene_frac covariate so
# it does not appear as a meaningless "difference" column in the output.
_tool_scores = d.drop(columns="species_gene_frac")
_tool_scores_filt = d_filt.drop(columns="species_gene_frac")
print((-(_tool_scores.subtract(_tool_scores["spgc-fit"], axis=0))).quantile([0.25, 0.5, 0.75]))
print((-(_tool_scores_filt.subtract(_tool_scores_filt["spgc-fit"], axis=0))).quantile([0.25, 0.5, 0.75]))

In [None]:
# Per-genome advantage of SPGC's eggnog precision over each tool (spgc-fit
# minus tool, "match" rows, missing -> 0), as quartiles. Computed for all
# genomes passing the depth filter, then for those with species_gene_frac > 0.9.
d = (
    genome.join(genome_to_spgc_strain)[
        lambda x: (x.species_id != "TODO")
        & (x.species_depth_sum > 0)
        & (x.species_depth_max > 0)
        # & ~x.species_id.isin(multi_genome_species)
    ]
    .join(benchmark.xs(("eggnog", "match"), level=("unit", "match")), rsuffix="_")
    .precision.unstack("tool")
    .fillna(0)
    .join(genome_to_spgc_strain.species_gene_frac)
)

d_filt = d[d.species_gene_frac > 0.9]

# Only the tool columns are scores. Exclude the species_gene_frac covariate so
# it does not appear as a meaningless "difference" column in the output.
_tool_scores = d.drop(columns="species_gene_frac")
_tool_scores_filt = d_filt.drop(columns="species_gene_frac")
print((-(_tool_scores.subtract(_tool_scores["spgc-fit"], axis=0))).quantile([0.25, 0.5, 0.75]))
print((-(_tool_scores_filt.subtract(_tool_scores_filt["spgc-fit"], axis=0))).quantile([0.25, 0.5, 0.75]))

In [None]:
# Per-genome advantage of SPGC's eggnog recall over each tool (spgc-fit minus
# tool, "match" rows, missing -> 0), as quartiles. Computed for all genomes
# passing the depth filter, then for those with species_gene_frac > 0.9.
d = (
    genome.join(genome_to_spgc_strain)[
        lambda x: (x.species_id != "TODO")
        & (x.species_depth_sum > 0)
        & (x.species_depth_max > 0)
        # & ~x.species_id.isin(multi_genome_species)
    ]
    .join(benchmark.xs(("eggnog", "match"), level=("unit", "match")), rsuffix="_")
    .recall.unstack("tool")
    .fillna(0)
    .join(genome_to_spgc_strain.species_gene_frac)
)

d_filt = d[d.species_gene_frac > 0.9]

# Only the tool columns are scores. Exclude the species_gene_frac covariate so
# it does not appear as a meaningless "difference" column in the output.
_tool_scores = d.drop(columns="species_gene_frac")
_tool_scores_filt = d_filt.drop(columns="species_gene_frac")
print((-(_tool_scores.subtract(_tool_scores["spgc-fit"], axis=0))).quantile([0.25, 0.5, 0.75]))
print((-(_tool_scores_filt.subtract(_tool_scores_filt["spgc-fit"], axis=0))).quantile([0.25, 0.5, 0.75]))

In [None]:
unit = "eggnog"
match = "match"

# Per-genome eggnog scores of each tool's "match" row for genomes whose species
# has total depth > 0.1; columns are (score, tool), missing -> 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum > 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# 2D histograms of SPGC (y) against each comparison tool (x), one row per
# score, sharing one colour bar of strain counts.
_tool_comparison_order = ["panphlan", "spanda-s4"]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order) + 0.5, 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
# "spgc2-fit" is commented out of the SPGC tool list when `benchmark` is built,
# so selecting it raised KeyError. Compare "spgc-fit", as every other figure does.
y = "spgc-fit"
nbins = 20
left_bound = 0.0
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # hist2d returns (counts, xedges, yedges, image); the image is kept for
        # the shared colour bar. cmin=1 leaves empty bins blank.
        *_, cbar_artist = ax.hist2d(
            x,
            y,
            data=d2,
            bins=bins,
            cmin=1,
            norm=mpl.colors.PowerNorm(1 / 1, vmin=0, vmax=21),
            cmap="magma_r",
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])

# Axes are shared, so limits set on the last axes apply to every panel.
ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
# NOTE(review): the norm caps at vmax=21, so ticks above ~21 are not shown.
fig.colorbar(
    cbar_artist, cax=cbar_ax, ticks=[1, 10, 20, 30, 40, 50], label="count strains"
)

In [None]:
unit = "eggnog"
match = "match"

# Per-genome eggnog scores of each tool's "match" row for genomes whose species
# has total depth >= 0.1; columns are (score, tool), missing -> 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum >= 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# SPGC (y) against each comparison tool (x), one row per score, coloured by
# the matched strain's species_gene_frac.
_tool_comparison_order = ["panphlan", "spanda-s4", "spgc-depth200"]

_score_order = ["precision", "recall", "f1"]
# Lower end of the y = x reference line. Defined here so the cell does not
# depend on a value left over from an earlier cell.
left_bound = 0.0
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    # Attach genome metadata, matched-strain metadata and species depth;
    # rsuffix guards against columns shared with depth_meta.
    d2 = d1.xs(score, level="score", axis="columns").join(
        genome.join(genome_to_spgc_strain).join(
            depth_meta, on="species_id", rsuffix="_"
        )
    )

    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        ax.scatter(
            x,
            y,
            data=d2,
            s=5,
            c="species_gene_frac",
            norm=mpl.colors.PowerNorm(1/2),
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])

# Axes are shared, so limits set on the last axes apply to every panel.
ax.set_ylim(-0.01, 1.01)
ax.set_xlim(-0.01, 1.01)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)
    lib.plot.rotate_xticklabels(ax=ax_col[-1], rotation=45)

In [None]:
unit = "eggnog"
match = "match"

# Per-genome eggnog scores of each tool's "match" row for genomes whose species
# has total depth >= 0.1; columns are (score, tool), missing -> 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum >= 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# SPGC (y) against each comparison tool (x), one row per score. Points are
# coloured by species_gene_frac and drawn in increasing order of it, so the
# highest values are plotted last (on top).
_tool_comparison_order = [
    "panphlan",
    "spanda-s4",
    # "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
# Lower end of the y = x reference line. Defined here so the cell does not
# depend on a value left over from an earlier cell.
left_bound = 0.0
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d2 = (
        d1.xs(score, level="score", axis="columns")
        .join(
            genome.join(genome_to_spgc_strain).join(
                depth_meta, on="species_id", rsuffix="_"
            )
        )
        .sort_values('species_gene_frac')
    )
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        ax.scatter(
            x,
            y,
            data=d2,
            s=10,
            c="species_gene_frac",
            norm=mpl.colors.PowerNorm(3, vmin=0, vmax=1),
            cmap='copper',
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )

# Axes are shared, so ticks/limits set on the last axes apply to every panel.
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.0, 1.0)
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_xlim(-0.0, 1.0)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)
    lib.plot.rotate_xticklabels(ax=ax_col[-1], rotation=45)

# Stand-alone horizontal colour bar (same norm/cmap as the panels), drawn in
# its own figure with the empty plotting axes hidden.
fig = plt.figure(figsize=(5, 2))
plt.scatter([], [], c=[], norm=mpl.colors.PowerNorm(3, vmin=0, vmax=1),
            cmap='copper')
plt.colorbar(ticks=[0.5, 0.6, 0.7, 0.8, 0.9, 1.0], orientation='horizontal')
ax = plt.gca()
ax.set_visible(False)

In [None]:
unit = "uhggtop"
match = "match"

# Per-genome uhggtop scores of each tool's "match" row for genomes whose
# species has total depth >= 0.1; columns are (score, tool), missing -> 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum >= 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# SPGC (y) against each comparison tool (x), one row per score.
_tool_comparison_order = ["panphlan", "spanda-s4", "spgc-depth200"]
_score_order = ["precision", "recall", "f1"]
# Lower end of the y = x reference line. Defined here so the cell does not
# depend on a value left over from an earlier cell.
left_bound = 0.0
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns").join(
        genome.join(genome_to_spgc_strain).join(
            depth_meta, on="species_id", rsuffix="_"
        )
    )
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # NOTE(review): `c=` is commented out, so matplotlib ignores `norm` and
        # `cmap` here (with a warning); they only matter if colouring is re-enabled.
        ax.scatter(
            x,
            y,
            data=d2,
            s=10,
            # c="species_gene_frac",
            norm=mpl.colors.PowerNorm(3, vmin=0, vmax=1),
            cmap='copper',
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )

# Axes are shared, so ticks/limits set on the last axes apply to every panel.
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.0, 1.0)
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_xlim(-0.0, 1.0)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)
    lib.plot.rotate_xticklabels(ax=ax_col[-1], rotation=45)

# Stand-alone horizontal colour bar drawn in its own figure.
# NOTE(review): the panels above are not colour-mapped, so this legend does not
# describe them.
fig = plt.figure(figsize=(5, 2))
plt.scatter([], [], c=[], norm=mpl.colors.PowerNorm(3, vmin=0, vmax=1),
            cmap='copper')
plt.colorbar(ticks=[0.5, 0.6, 0.7, 0.8, 0.9, 1.0], orientation='horizontal')
ax = plt.gca()
ax.set_visible(False)

In [None]:
unit = "cog"
match = "match"

# Per-genome COG scores of each tool's "match" row for genomes whose species
# has total depth >= 0.1; columns are (score, tool), missing -> 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum >= 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# SPGC (y) against each comparison tool (x), one row per score.
_tool_comparison_order = ["panphlan", "spanda-s4", "spgc-depth200"]
_score_order = ["precision", "recall", "f1"]
# Lower end of the y = x reference line. Defined here so the cell does not
# depend on a value left over from an earlier cell.
left_bound = 0.0
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns").join(
        genome.join(genome_to_spgc_strain).join(
            depth_meta, on="species_id", rsuffix="_"
        )
    )
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # NOTE(review): `c=` is commented out, so matplotlib ignores `norm` and
        # `cmap` here (with a warning); they only matter if colouring is re-enabled.
        ax.scatter(
            x,
            y,
            data=d2,
            s=10,
            # c="species_gene_frac",
            norm=mpl.colors.PowerNorm(3, vmin=0, vmax=1),
            cmap='copper',
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )

# Axes are shared, so ticks/limits set on the last axes apply to every panel.
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.0, 1.0)
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_xlim(-0.0, 1.0)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)
    lib.plot.rotate_xticklabels(ax=ax_col[-1], rotation=45)

# Stand-alone horizontal colour bar drawn in its own figure.
# NOTE(review): the panels above are not colour-mapped, so this legend does not
# describe them.
fig = plt.figure(figsize=(5, 2))
plt.scatter([], [], c=[], norm=mpl.colors.PowerNorm(3, vmin=0, vmax=1),
            cmap='copper')
plt.colorbar(ticks=[0.5, 0.6, 0.7, 0.8, 0.9, 1.0], orientation='horizontal')
ax = plt.gca()
ax.set_visible(False)

In [None]:
unit = "eggnog"
match = "match"

# Per-genome eggnog scores of each tool's "match" row for genomes whose species
# has total depth >= 0.1; columns are (score, tool), missing -> 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum >= 0.1
    ][["f1", "precision", "recall", "jaccard"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)
# SPGC scores with genome/strain metadata, for genomes that have a matched
# strain. `genome_to_spgc_strain` already carries `species_gene_frac`, so the
# former extra `.join(spgc_qc.species_gene_frac)` raised "columns overlap but
# no suffix specified"; the existing column is used instead, as in the figure
# cells above.
d2 = (
    d1.xs("spgc-fit", level="tool", axis="columns")
    .join(
        genome.join(genome_to_spgc_strain).join(
            depth_meta, on="species_id", rsuffix="_"
        )
    )[lambda x: x.strain != -1]
    .sort_values("species_gene_frac")
)

# Regress each SPGC score on species_gene_frac (overlaid on one axes) and
# report the Pearson correlation.
for score in ["precision", "recall", "f1"]:
    sns.regplot(x="species_gene_frac", y=score, data=d2, label=score)
    print(score, sp.stats.pearsonr(d2["species_gene_frac"], d2[score]))
plt.legend()

In [None]:
unit = "cog"
match = "match"

# Per-genome COG scores of each tool's "match" row for genomes whose species
# has total depth >= 0.1; columns are (score, tool), missing -> 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum >= 0.1
    ][["f1", "precision", "recall"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# SPGC (y) against each comparison tool (x), one row per score, coloured by
# max strain depth.
_tool_comparison_order = ["panphlan", "spanda-s4", "spgc-depth200"]

_score_order = ["precision", "recall", "f1"]
# Lower end of the y = x reference line. Defined here so the cell does not
# depend on a value left over from an earlier cell.
left_bound = 0.0
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns").join(
        genome.join(genome_to_spgc_strain).join(
            depth_meta, on="species_id", rsuffix="_"
        )
    )
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        ax.scatter(
            x,
            y,
            data=d2,
            s=5,
            c="strain_depth_max",
            norm=mpl.colors.PowerNorm(1/2),
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])


pad = 1e-2
# Axes are shared, so ticks/limits set on the last axes apply to every panel
# (overriding the per-panel ticks set in the loop).
ax.set_yticks(np.unique([0.0, 0.5, 0.75, 0.9, 0.95, 0.99]))
ax.set_ylim(-0.1, 1.00)

ax.set_xticks(np.unique([0.0, 0.5, 0.75, 0.9, 0.95, 0.99]))
ax.set_xlim(-0.1, 1.00)

# ax.set_xlim(left_bound, 1)
# ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)
    # *_, artist = ax_row[-1].hist2d(x, y, data=d1.head(0), bins=np.linspace(0, 1, num=21), norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=30), cmap='Blues')
    # fig.colorbar(artist, cax=ax_row[-1])

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)
    lib.plot.rotate_xticklabels(ax=ax_col[-1], rotation=45)

In [None]:
# Display the benchmark results table.
# To inspect a single result file instead: benchmark.iloc[0].accuracy_path
benchmark#.iloc[0].accuracy_path

In [None]:
unit = "eggnog"
match = "match"

# Per-genome eggnog/match scores with a (score, tool) column MultiIndex; tools
# without a result for a genome score 0.
# NOTE(review): this filters species_depth_sum > 0.1 while the cog scatter cell
# uses >= 0.1 — confirm which threshold is intended.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_depth_sum > 0.1
    ][["f1", "precision", "recall"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# Grid of 2D histograms. Rows are scores and columns are comparison tools. Each
# panel bins spgc-fit (y) against the comparison tool (x) over the same genomes.
_tool_comparison_order = [
    "panphlan",
    "spanda-s4",
    # "nnmatched-m50",
    # "nnmatched-m10",
    # "nnmatched-m1",
    # "nnmatched-m0",
    "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order) + 0.5, 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
nbins = 20
left_bound = 0.0
# Bin edges on [0, 1]. With left_bound == 0 this equals np.linspace(0, 1, nbins + 1).
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    # 25th/50th/75th percentiles of the spgc-fit score.
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # cmin=1 leaves empty bins uncoloured. PowerNorm(1 / 1) (gamma 1) is linear
        # over [0, 21]. The mesh artist from the last panel drawn feeds the shared
        # colorbar below.
        *_, cbar_artist = ax.hist2d(
            x,
            y,
            data=d2,
            bins=bins,
            cmin=1,
            norm=mpl.colors.PowerNorm(1 / 1, vmin=0, vmax=21),
            cmap="magma_r",
        )
        # y = x reference: bins above the diagonal are genomes where spgc-fit scores higher.
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference (spgc-fit minus comparison tool) and
        # the Wilcoxon signed-rank p-value over the same genome pairs.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])

# Axes are shared, so limits set on the last panel apply to the whole grid.
ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)
    # *_, artist = ax_row[-1].hist2d(x, y, data=d1.head(0), bins=np.linspace(0, 1, num=21), norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=30), cmap='Blues')
    # fig.colorbar(artist, cax=ax_row[-1])

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

# Reserve the right 20% of the figure for one colorbar shared by all panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
# NOTE(review): the norm's vmax is 21, so ticks 30/40/50 fall outside the colour
# range and are not drawn.
fig.colorbar(
    cbar_artist, cax=cbar_ax, ticks=[1, 10, 20, 30, 40, 50], label="count strains"
)
# fig.tight_layout(rect=(0, 0, 0.85, 0.67))
# ax.set_xlabel(x)
# ax.set_ylabel(y)

In [None]:
unit = "eggnog"
match = "match"

# Wide table of eggnog/match scores: one row per genome, (score, tool) columns.
# Only genomes whose species has summed depth above 0.1 are kept; a tool with
# no result for a genome gets 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))
    .loc[lambda df: df["species_depth_sum"] > 0.1, ["f1", "precision", "recall"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

# F1 of the two SPGC variants side by side, ordered by spgc-fit, with genome metadata.
d1["f1"][["spgc-fit", "spgc-depth200"]].sort_values("spgc-fit").join(genome)

In [None]:
# spgc-fit eggnog/match results for genomes of multi-genome species, best F1 first,
# annotated with genome and species-depth metadata.
(
    benchmark
    .xs(("eggnog", "match", "spgc-fit"), level=("unit", "match", "tool"))
    .sort_values("f1", ascending=False)
    .join(genome.join(depth_meta, on="species_id"))
    .loc[lambda df: df["species_id"].isin(multi_genome_species)]
)

In [None]:
# eggnog/match results of every tool for one genome, best F1 first (top 20),
# annotated with that genome's metadata.
# `drop_level=False` keeps the `genome_id` index level so the join aligns on it.
# Otherwise the remaining index holds tool names, which cannot match the
# genome_id index of `genome`, and the metadata columns come back empty (NaN).
(
    benchmark.xs(
        ("eggnog", "match", "Veillonella-sp-3-1-44_MAF-2"),
        level=("unit", "match", "genome_id"),
        drop_level=False,
    )
    .sort_values("f1", ascending=False)
    .head(20)
    .join(genome.join(depth_meta, on="species_id"))
)

In [None]:
# eggnog/match results of every tool for one genome, best F1 first (top 20),
# annotated with that genome's metadata.
# `drop_level=False` keeps the `genome_id` index level so the join aligns on it.
# Otherwise the remaining index holds tool names, which cannot match the
# genome_id index of `genome`, and the metadata columns come back empty (NaN).
(
    benchmark.xs(
        ("eggnog", "match", "Veillonella-sp-6-1-27_MAF-2"),
        level=("unit", "match", "genome_id"),
        drop_level=False,
    )
    .sort_values("f1", ascending=False)
    .head(20)
    .join(genome.join(depth_meta, on="species_id"))
)

In [None]:
# Community plot for the filt-poly00 sfacts fit of species 100003, after
# `drop_low_abundance_strains(0.05)`. The world is loaded here, from the same file
# as `w1` in the next cell, so this cell does not depend on that cell running first.
import sfacts as sf

sf.plot.plot_community(
    sf.data.World.load(
        "data/group/xjin/species/sp-100003/r.proc.gtpro.filt-poly00-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
    )
    .drop_low_abundance_strains(0.05)
    .community
)

In [None]:
import sfacts as sf

# Two sfacts fits for species 100003 that differ in the polymorphism filter
# (filt-poly05 vs filt-poly00 in the file names).
w0 = sf.data.World.load('data/group/xjin/species/sp-100003/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc')
w1 = sf.data.World.load('data/group/xjin/species/sp-100003/r.proc.gtpro.filt-poly00-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc')

print(w0.sizes, w1.sizes)

# Samples present in both fits, computed once. They are sorted so the sample order
# passed to `sel` is reproducible: iteration order of a set of strings varies
# between interpreter runs because of hash randomization.
_shared_samples = sorted(set(w1.sample.values) & set(w0.sample.values))
_w1_shared = w1.sel(sample=_shared_samples)
sf.plot.plot_metagenotype(_w1_shared.random_sample(position=634))
sf.plot.plot_community(_w1_shared.random_sample(position=634))

In [None]:
unit = "eggnog"
match = "match"

# Genomes whose species has nonzero summed depth (per genome_to_spgc_strain).
_genome_list = idxwhere(genome_to_spgc_strain.species_depth_sum > 0)

_tool_list = [
    "spgc-fit",
    "panphlan",
    "spanda-s4",
    "spgc-depth200",
    # "nnmatched-m50"
]
_score_list = ["precision", "recall", "f1"]
xvar = "strain_depth_max"
# Number of consecutive genomes (ordered by `xvar`) averaged for each trend line.
_rolling_window = 10

_tool_palette = lib.plot.construct_ordered_palette_from_list(
    _tool_list, colors=["tab:blue", "tab:green", "tab:orange", "tab:purple"]
)

# One row per (genome, tool); a tool with no benchmark result for a genome scores 0.
d0 = (
    benchmark.xs((unit, match), level=("unit", "match"))[_score_list]
    .reindex(product(_genome_list, _tool_list))
    .fillna(0)
    .join(genome)
    # [lambda x: ~x.species_id.isin(multi_genome_species)]
    .join(depth_meta, on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)

fig, axs = plt.subplots(
    len(_score_list),
    figsize=(5, 3 * len(_score_list)),
    sharex=True,
    sharey=True,
)

for _tool in _tool_list:
    d1 = d0.xs(_tool, level="tool")
    # Sort once per tool; the rolling averages are taken along this order.
    _d1_sorted = d1.sort_values(xvar)
    for _score, ax in zip(_score_list, axs.flatten()):
        d2 = _d1_sorted.assign(
            xvar_rolling_average=lambda d: d[xvar].rolling(window=_rolling_window).mean(),
            score_rolling_average=lambda d: d[_score].rolling(window=_rolling_window).mean(),
        )
        ax.scatter(
            xvar,
            _score,
            data=d2,
            label=_tool,
            s=8,
            alpha=0.5,
            facecolor="none",
            color=_tool_palette[_tool],
        )
        ax.plot(
            "xvar_rolling_average",
            "score_rolling_average",
            data=d2,
            label="__nolegend__",
            lw=1,
            linestyle="-",
            color=_tool_palette[_tool],
        )

for _score, ax in zip(_score_list, axs.flatten()):
    ax.set_ylabel(_score)

# Shared x-axis; symlog is linear below 0.1, so values at or near zero stay plottable.
ax.set_xscale("symlog", linthresh=1e-1, linscale=0.1)
ax.set_ylim(-0.05, 1.05)

ax.legend(bbox_to_anchor=(1, 1))

In [None]:
unit = "uhggtop"
match = "match"

# Genomes whose species has nonzero summed depth (per genome_to_spgc_strain).
_genome_list = idxwhere(genome_to_spgc_strain.species_depth_sum > 0)

_tool_list = [
    "spgc-fit",
    "panphlan",
    "spanda-s4",
    "spgc-depth200",
    # "nnmatched-m50"
]
_score_list = ["precision", "recall", "f1"]
xvar = "strain_depth_max"
# Number of consecutive genomes (ordered by `xvar`) averaged for each trend line.
_rolling_window = 10

_tool_palette = lib.plot.construct_ordered_palette_from_list(
    _tool_list, colors=["tab:blue", "tab:green", "tab:orange", "tab:purple"]
)

# One row per (genome, tool); a tool with no benchmark result for a genome scores 0.
d0 = (
    benchmark.xs((unit, match), level=("unit", "match"))[_score_list]
    .reindex(product(_genome_list, _tool_list))
    .fillna(0)
    .join(genome)
    # [lambda x: ~x.species_id.isin(multi_genome_species)]
    .join(depth_meta, on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)

fig, axs = plt.subplots(
    len(_score_list),
    figsize=(5, 3 * len(_score_list)),
    sharex=True,
    sharey=True,
)

for _tool in _tool_list:
    d1 = d0.xs(_tool, level="tool")
    # Sort once per tool; the rolling averages are taken along this order.
    _d1_sorted = d1.sort_values(xvar)
    for _score, ax in zip(_score_list, axs.flatten()):
        d2 = _d1_sorted.assign(
            xvar_rolling_average=lambda d: d[xvar].rolling(window=_rolling_window).mean(),
            score_rolling_average=lambda d: d[_score].rolling(window=_rolling_window).mean(),
        )
        ax.scatter(
            xvar,
            _score,
            data=d2,
            label=_tool,
            s=8,
            alpha=0.5,
            facecolor="none",
            color=_tool_palette[_tool],
        )
        ax.plot(
            "xvar_rolling_average",
            "score_rolling_average",
            data=d2,
            label="__nolegend__",
            lw=1,
            linestyle="-",
            color=_tool_palette[_tool],
        )

for _score, ax in zip(_score_list, axs.flatten()):
    ax.set_ylabel(_score)

# Shared x-axis; symlog is linear below 0.1, so values at or near zero stay plottable.
ax.set_xscale("symlog", linthresh=1e-1, linscale=0.1)
ax.set_ylim(-0.05, 1.05)

ax.legend(bbox_to_anchor=(1, 1))

In [None]:
unit = "cog"
match = "match"

# Genomes whose species has nonzero summed depth (per genome_to_spgc_strain).
_genome_list = idxwhere(genome_to_spgc_strain.species_depth_sum > 0)

_tool_list = [
    "spgc-fit",
    "panphlan",
    "spanda-s4",
    "spgc-depth200",
    # "nnmatched-m50"
]
_score_list = ["precision", "recall", "f1"]
xvar = "strain_depth_max"
# Number of consecutive genomes (ordered by `xvar`) averaged for each trend line.
_rolling_window = 10

_tool_palette = lib.plot.construct_ordered_palette_from_list(
    _tool_list, colors=["tab:blue", "tab:green", "tab:orange", "tab:purple"]
)

# One row per (genome, tool); a tool with no benchmark result for a genome scores 0.
d0 = (
    benchmark.xs((unit, match), level=("unit", "match"))[_score_list]
    .reindex(product(_genome_list, _tool_list))
    .fillna(0)
    .join(genome)
    # [lambda x: ~x.species_id.isin(multi_genome_species)]
    .join(depth_meta, on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)

fig, axs = plt.subplots(
    len(_score_list),
    figsize=(5, 3 * len(_score_list)),
    sharex=True,
    sharey=True,
)

for _tool in _tool_list:
    d1 = d0.xs(_tool, level="tool")
    # Sort once per tool; the rolling averages are taken along this order.
    _d1_sorted = d1.sort_values(xvar)
    for _score, ax in zip(_score_list, axs.flatten()):
        d2 = _d1_sorted.assign(
            xvar_rolling_average=lambda d: d[xvar].rolling(window=_rolling_window).mean(),
            score_rolling_average=lambda d: d[_score].rolling(window=_rolling_window).mean(),
        )
        ax.scatter(
            xvar,
            _score,
            data=d2,
            label=_tool,
            s=8,
            alpha=0.5,
            facecolor="none",
            color=_tool_palette[_tool],
        )
        ax.plot(
            "xvar_rolling_average",
            "score_rolling_average",
            data=d2,
            label="__nolegend__",
            lw=1,
            linestyle="-",
            color=_tool_palette[_tool],
        )

for _score, ax in zip(_score_list, axs.flatten()):
    ax.set_ylabel(_score)

# Shared x-axis; symlog is linear below 0.1, so values at or near zero stay plottable.
ax.set_xscale("symlog", linthresh=1e-1, linscale=0.1)
ax.set_ylim(-0.05, 1.05)

ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Parameters are set explicitly rather than relying on values left by the
# preceding cell.
unit = "cog"
match = "match"
_score_list = ["precision", "recall", "f1"]
# species_gene_frac threshold separating good from bad SPGC QC.
_qc_threshold = 0.9

# spgc-fit scores for every genome in genome_to_spgc_strain (missing results -> 0),
# joined with strain/depth metadata and the SPGC QC metric.
d0 = (
    benchmark.xs((unit, match, "spgc-fit"), level=("unit", "match", "tool"))[_score_list]
    .fillna(0)
    .reindex(genome_to_spgc_strain.index, fill_value=0)
    .join(genome_to_spgc_strain, rsuffix="_")
    .join(spgc_qc.species_gene_frac)
)

d1 = d0[lambda x: x.species_depth_sum > 0]
# NOTE: genomes with a missing species_gene_frac fall into neither group.
d1_bad_qc = d1[lambda x: x.species_gene_frac < _qc_threshold]
d1_good_qc = d1[lambda x: x.species_gene_frac >= _qc_threshold]

bins = np.linspace(0, 1, num=21)
fig, ax = plt.subplots()
ax.hist(
    d1_bad_qc.f1,
    alpha=0.5,
    bins=bins,
    color="tab:orange",
    label=f"species_gene_frac < {_qc_threshold}",
)
ax.hist(
    d1_good_qc.f1,
    alpha=0.5,
    bins=bins,
    color="tab:blue",
    label=f"species_gene_frac >= {_qc_threshold}",
)
ax.set_xlabel(f"f1 ({unit}, spgc-fit)")
ax.set_ylabel("count genomes")
ax.legend()

In [None]:
unit = "eggnog"
match = "match"

# Per-genome eggnog/match scores with a (score, tool) column MultiIndex, restricted
# to genomes of species in `multi_genome_species`; missing tool results score 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_id.isin(multi_genome_species)
    ][["f1", "precision", "recall"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# Grid of 2D histograms. Rows are scores and columns are comparison tools. Each
# panel bins spgc-fit (y) against the comparison tool (x) over the same genomes.
_tool_comparison_order = [
    "panphlan",
    "spanda-s4",
    # "nnmatched-m50",
    # "nnmatched-m10",
    # "nnmatched-m1",
    # "nnmatched-m0",
    "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order) + 0.5, 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
nbins = 20
left_bound = 0.0
# Bin edges on [0, 1]. With left_bound == 0 this equals np.linspace(0, 1, nbins + 1).
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    # 25th/50th/75th percentiles of the spgc-fit score.
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        # cmin=1 leaves empty bins uncoloured. PowerNorm(1 / 1) (gamma 1) is linear
        # over [0, 21]. The mesh artist from the last panel drawn feeds the shared
        # colorbar below.
        *_, cbar_artist = ax.hist2d(
            x,
            y,
            data=d2,
            bins=bins,
            cmin=1,
            norm=mpl.colors.PowerNorm(1 / 1, vmin=0, vmax=21),
            cmap="magma_r",
        )
        # y = x reference: bins above the diagonal are genomes where spgc-fit scores higher.
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference (spgc-fit minus comparison tool) and
        # the Wilcoxon signed-rank p-value over the same genome pairs.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )
        ax.set_xticks([0, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.6, 0.8, 1.0])

# Axes are shared, so limits set on the last panel apply to the whole grid.
ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)
    # *_, artist = ax_row[-1].hist2d(x, y, data=d1.head(0), bins=np.linspace(0, 1, num=21), norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=30), cmap='Blues')
    # fig.colorbar(artist, cax=ax_row[-1])

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

# Reserve the right 20% of the figure for one colorbar shared by all panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
# NOTE(review): the norm's vmax is 21, so ticks 30/40/50 fall outside the colour
# range and are not drawn.
fig.colorbar(
    cbar_artist, cax=cbar_ax, ticks=[1, 10, 20, 30, 40, 50], label="count strains"
)
# fig.tight_layout(rect=(0, 0, 0.85, 0.67))
# ax.set_xlabel(x)
# ax.set_ylabel(y)

In [None]:
unit = "eggnog"
match = "match"
# Lower end of the y = x reference line, set here so the cell does not depend on
# a value left over from another cell.
left_bound = 0.0

# Per-genome eggnog/match scores with (score, tool) columns, restricted to genomes
# of species in `multi_genome_species`. Missing tool results score 0.
d1 = (
    benchmark.xs((unit, match), level=("unit", "match"))
    .join(genome.join(depth_meta, on="species_id"))[
        lambda x: x.species_id.isin(multi_genome_species)
    ][["f1", "precision", "recall"]]
    .unstack()
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
_tool_comparison_order = [
    "panphlan",
    "spanda-s4",
    # "nnmatched-m50",
    # "nnmatched-m10",
    # "nnmatched-m1",
    # "nnmatched-m0",
    "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order), 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        ax.scatter(
            x,
            y,
            data=d2,
            s=5,
        )
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quantiles of the paired difference (spgc-fit minus comparison tool) and
        # the Wilcoxon signed-rank p-value over the same genome pairs.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.05, 0.25, 0.5, 0.75, 0.95]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )

# Non-linear scale that stretches scores close to 1:
#   forward u = -log(1 + pad - v), inverse v = 1 + pad - exp(-u).
# `pad` is bound as a default argument at definition time, so later reassignment
# of the global `pad` cannot change this figure on redraw.
pad = 1e-2
_near_one_functions = (
    lambda v, _pad=pad: -np.log(1 + _pad - v),
    lambda u, _pad=pad: 1 + _pad - np.exp(-u),
)
_ticks = [0.0, 0.5, 0.9, 0.99, 1.00]

# Axes are shared, so the scale, ticks and limits set on the last panel apply to all.
ax.set_yscale("function", functions=_near_one_functions)
ax.set_yticks(_ticks)
ax.set_ylim(-0.1, 1.00)

ax.set_xscale("function", functions=_near_one_functions)
ax.set_xticks(_ticks)
ax.set_xlim(-0.1, 1.00)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

In [None]:
# GT-Pro species depth: long (sample, species_id) table pivoted to a wide
# sample x species matrix, with absent species filled as 0.
motu_depth = (
    pd.read_table(
        "data/group/xjin/r.proc.gtpro.species_depth.tsv",
        index_col=["sample", "species_id"],
    )["depth"]
    .unstack(fill_value=0)
    # Disabled CF_* sample relabelling and CF_11/CF_15 sample-swap correction:
    # .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    # .rename({'CF_15': 'CF_11', 'CF_11': 'CF_15'})
)
# Relative abundance: each sample's species depths divided by that sample's total.
motu_rabund = motu_depth.div(motu_depth.sum(axis=1), axis=0)

motu_rabund

In [None]:
# Strain-level depth. For each species, its total depth (motu_depth) is split among
# its sfacts strains in proportion to their per-sample relative abundance.
# - Strains whose relative abundance summed over samples is <= 0.05 are pooled
#   with the unassigned remainder into strain -1.
# - Species without a comm.tsv are recorded in `missing_files`, and all of their
#   depth goes to strain -1.
sotu_depth = []
missing_files = []
for species_id in motu_depth.columns:
    path = f"data/group/xjin/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        d = (
            pd.read_table(path, index_col=["sample", "strain"])
            .squeeze()
            # Absent (sample, strain) pairs become 0 rather than NaN, so they
            # cannot propagate NaN through the normalization below.
            .unstack(fill_value=0)
            # .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            # .rename({'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap.
        )
    except FileNotFoundError:
        missing_files.append(path)
        d = pd.DataFrame([])
    _keep_strains = idxwhere(d.sum() > 0.05)
    assert d.index.isin(motu_depth.index).all(), f"{path}: samples not in motu_depth"
    d = d.reindex(index=motu_depth.index, columns=_keep_strains, fill_value=0)
    # Remaining abundance goes to strain -1. It is clipped at 0 for samples where
    # the kept strains already sum to more than 1.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    d = d.clip(lower=0)
    d = d.divide(d.sum(1), axis=0)
    d = d.multiply(motu_depth[species_id], axis=0)
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    sotu_depth.append(d)
sotu_depth = pd.concat(sotu_depth, axis=1)
# Per-sample relative abundance of each strain.
# NOTE: a sample with zero total depth yields NaN here (0 / 0).
sotu_rabund = sotu_depth.divide(sotu_depth.sum(1), axis=0)
len(motu_depth.columns), len(missing_files)

In [None]:
# Strain depths for samples named "xjin_*", keeping strains whose total depth
# across those samples exceeds 0.01.
xjin_sotu_depth = sotu_depth.loc[sotu_depth.index.str.startswith("xjin_")]
xjin_sotu_depth = xjin_sotu_depth[idxwhere(xjin_sotu_depth.sum() > 0.01)]
# Distribution of per-strain total depth on a log10 scale.
plt.hist(xjin_sotu_depth.sum().pipe(np.log10), bins=20)

In [None]:
# Strains with total depth > 1 get their own row; everything else is pooled
# into a single "other" column before clustering.
xjin_sotu_strain_list = idxwhere(xjin_sotu_depth.sum() > 1)
other_strain_depth = xjin_sotu_depth.drop(columns=xjin_sotu_strain_list).sum(axis=1)

d = xjin_sotu_depth[xjin_sotu_strain_list].assign(other=other_strain_depth)
sns.clustermap(
    d.T,
    norm=mpl.colors.SymLogNorm(linthresh=0.1),
    metric="cosine",
    figsize=(5, 8),
    xticklabels=0,
    yticklabels=0,
)

In [None]:
# Species depths for samples named "xjin_*" (species IDs as strings), keeping
# species whose total depth across those samples exceeds 0.01.
xjin_motu_depth = motu_depth.loc[motu_depth.index.str.startswith("xjin_")].rename(columns=str)
xjin_motu_depth = xjin_motu_depth[idxwhere(xjin_motu_depth.sum() > 0.01)]
# Distribution of per-species total depth on a log10 scale.
plt.hist(xjin_motu_depth.sum().pipe(np.log10), bins=20)

In [None]:
# Species ID of every genome in genome_to_spgc_strain, excluding genomes whose
# species is still the placeholder "TODO".
species_has_genome_list = (
    genome_to_spgc_strain
    .assign(species=genome.species_id)
    .loc[lambda df: df["species"] != "TODO", "species"]
    .values
)

In [None]:
xjin_motu_strain_list = idxwhere(xjin_motu_depth.sum() > 1)
d = xjin_motu_depth[xjin_motu_strain_list]
# Number of genomes in genome_to_spgc_strain per species (0 when none), used to
# colour the row annotation bar.
_num_genomes_per_species = (
    genome_to_spgc_strain.assign(species=genome.species_id)
    .species.value_counts()
    .reindex(d.columns, fill_value=0)
)
_num_genomes_palette = lib.plot.construct_ordered_palette(
    np.arange(_num_genomes_per_species.max()) + 1, cm="viridis", other="white"
)

# A single colour-mapping definition shared by the heatmap and its standalone
# colorbar, so the colorbar always describes what the heatmap shows.
_depth_cmap = "pink"
_depth_norm_kwargs = dict(linthresh=0.1, vmin=0, vmax=1e4)

g = sns.clustermap(
    d.T.rename_axis(index=''),
    cmap=_depth_cmap,
    norm=mpl.colors.SymLogNorm(**_depth_norm_kwargs),
    metric="cosine",
    figsize=(5, 8),
    xticklabels=0,
    yticklabels=0,
    dendrogram_ratio=(0.1, 0.1),
    row_colors=_num_genomes_per_species.map(_num_genomes_palette),
)
g.ax_cbar.set_visible(False)
# Move the row-colour bar to the right-hand side of the heatmap.
# See https://stackoverflow.com/questions/65677031/move-seaborn-clustermap-row-colors-bar-to-the-other-side-of-the-plot
ax_row_colors = g.ax_row_colors
box = ax_row_colors.get_position()
box_heatmap = g.ax_heatmap.get_position()
ax_row_colors.set_position([box_heatmap.max[0] + 0.01, box.y0, box.width*1.5, box.height])


# Standalone horizontal colorbar for the heatmap (its built-in colorbar is hidden
# above). An empty scatter supplies the mappable; its axes is then hidden.
fig = plt.figure(figsize=(5, 2))
plt.scatter(
    [],
    [],
    c=[],
    norm=mpl.colors.SymLogNorm(**_depth_norm_kwargs),
    cmap=_depth_cmap,
)
plt.colorbar(orientation="horizontal")
ax = plt.gca()
ax.set_visible(False)

# Legend for the row-colour bar: number of genomes per species.
fig, ax = plt.subplots(figsize=(3, 1))
for n in np.arange(_num_genomes_per_species.max()) + 1:
    ax.scatter([], [], color=_num_genomes_palette[n], label=n, lw=0.5, edgecolor='k', marker='s')
ax.legend(ncols=2)
lib.plot.hide_axes_and_spines(ax)