## Preamble

### Project Template

In [None]:
# Load IPython's autoreload extension so edited project modules (e.g. lib/)
# can be reloaded without restarting the kernel.
# NOTE(review): no `%autoreload <mode>` line follows — confirm the intended
# reload mode is actually active for the IPython version in use.
%load_ext autoreload

In [None]:
# Work from the project root (PROJECT_ROOT env var; KeyError if unset) so the
# relative meta/, ref/, data/ and fig/ paths used below resolve.
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
_os.path.realpath(_os.path.curdir)  # display the resolved working directory

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

## Style

In [None]:
# Seaborn "paper" plotting context with high-resolution (300 dpi) figures.
sns.set_context("paper")
plt.rcParams.update({"figure.dpi": 300})

# Load and Check Data

In [None]:
genome_list = idxwhere(
    pd.read_table("meta/genome_group.tsv", index_col="genome_id").genome_group_id
    == "xjin"
)
len(genome_list)

In [None]:
genome0 = pd.read_table("meta/genome.tsv", index_col="genome_id").loc[genome_list]
assert genome0.index.is_unique
genome = genome0.loc[genome_list][lambda x: x.species_id != "UNKNOWN"]

In [None]:
species_list = genome.species_id.unique()
genome.species_id.value_counts().agg(["sum", "count"])

### Taxonomy

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a Series keyed by rank prefix.

    The string must have exactly seven fields (domain through species); pandas
    raises ValueError when the field count does not match the seven ranks.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)[lambda x: x.Genome == x.Species_rep]
    .assign(species_id=lambda x: "1" + x.MGnify_accession.str.split("-").str[2])
    .set_index("species_id")
    .Lineage.apply(parse_taxonomy_string)
)
species_taxonomy.loc[species_list]

In [None]:
species_taxonomy.loc[species_list].p__.value_counts()

In [None]:
# Plotting order of phyla; presumably this also fixes each phylum's color in
# construct_ordered_palette (TODO confirm against lib.plot).
# Deliberately omitted (noted as absent from species_list1, or with unclear
# relation to Firmicutes_A/C): p__Firmicutes_B, p__Firmicutes_G,
# p__Firmicutes_I, p__Cyanobacteria.
phylum_order = [
    "p__Euryarchaeota",
    "p__Thermoplasmatota",
    "p__Firmicutes",
    "p__Firmicutes_A",
    "p__Firmicutes_C",
    "p__Actinobacteriota",
    "p__Synergistota",
    "p__Fusobacteriota",
    "p__Campylobacterota",
    "p__Proteobacteria",
    "p__Desulfobacterota_A",
    "p__Bacteroidota",
    "p__Verrucomicrobiota",
]

phylum_palette = lib.plot.construct_ordered_palette(
    phylum_order, cm="rainbow", desaturate_levels=[1.0, 0.5]
)

# Preview the palette: print each phylum's color and draw a swatch-only legend.
for p__ in phylum_order:
    print(p__, phylum_palette[p__])
    plt.scatter([], [], label=p__, color=phylum_palette[p__])
plt.legend(ncols=1)
lib.plot.hide_axes_and_spines()

In [None]:
# Species IDs (as strings) belonging to the "xjin" species group.
species_group = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda df: df["species_group_id"] == "xjin", "species_id"]
    .astype(str)
    .tolist()
)
len(species_group)

In [None]:
qc_code_meaning = {
    -1: "fail",
    0: "passes",
    1: "noise",
    2: "species-gene",
    3: "both",
}


def assign_qc_code(x):
    """Encode which of the two strain QC checks row ``x`` fails.

    The code is additive: +1 when the depth-ratio noise check
    (``passes_log_selected_gene_depth_ratio_std``) fails and +2 when the
    species-gene fraction check (``passes_species_gene_frac``) fails, giving
    0 (passes), 1 (noise), 2 (species-gene) or 3 (both); see
    ``qc_code_meaning``. Code -1 is never produced by this function.
    """
    fails_species_gene = not x.passes_species_gene_frac
    fails_noise = not x.passes_log_selected_gene_depth_ratio_std
    return 2 * int(fails_species_gene) + int(fails_noise)

In [None]:
# Genotype-matching statistics for every (benchmark genome, inferred strain)
# pair, joined with the StrainPGC strain metadata of each strain and a QC code.
# Upstream target:
# data/group/xjin/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.STRAIN_MATCH_BENCHMARK_GRID.flag

strain_match = {}
missing = []
for species in species_list:
    if species == "UNKNOWN":
        continue
    strain_diss_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.geno_matching_stats.tsv"
    strain_meta_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_meta-s95-d100-a0-pos100-std25.tsv"
    # Both inputs are required: record and skip species missing either one
    # instead of crashing part-way through on the strain-meta read.
    _missing_inputs = [
        p for p in (strain_diss_path, strain_meta_path) if not os.path.exists(p)
    ]
    if _missing_inputs:
        missing.extend(_missing_inputs)
        continue
    # The strain-meta "genome_id" index is relabelled "strain" so it can be
    # joined onto the "strain" level of the matching statistics.
    strain_meta = pd.read_table(strain_meta_path, index_col="genome_id").rename_axis(
        "strain"
    )
    strain_match[species] = pd.read_table(
        strain_diss_path, index_col=["genome_id", "strain"]
    ).join(strain_meta, on="strain")

# Report skipped species; `missing` is reused (and reset) by later cells.
if missing:
    print(len(missing), "strain-matching input files are missing; those species were skipped.")

strain_match = (
    pd.concat(strain_match.values()).reset_index().set_index(["genome_id", "strain"])
).assign(qc_code=lambda x: x.apply(assign_qc_code, axis=1))
strain_match

In [None]:
# NOTE: Each benchmark genome is matched to the inferred strain with the lowest
# genotype_dissimilarity; ties go to the strain with more strain-pure samples
# (num_strain_sample, descending).
genome_to_spgc_strain = (
    strain_match.reset_index()
    .sort_values(
        by=["genotype_dissimilarity", "num_strain_sample"],
        ascending=[True, False],
    )
    .groupby("genome_id")
    .head(n=1)
    .set_index("genome_id")
)

# Distribution of best-match dissimilarities on log-spaced bins from 1e-5 to 1.
plt.hist(
    genome_to_spgc_strain["genotype_dissimilarity"],
    bins=np.logspace(-5, 0),
)
plt.xscale("log")

genome_to_spgc_strain

In [None]:
# data/group/xjin/r.proc.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.SPECIES_DEPTH_BENCHMARK_GRID.flag
# data/group/xjin/r.proc.gene99_v15-v22-agg75.ACCURACY_BENCHMARK_GRID.flag
# data/group/xjin/r.proc.gene99_v15-v22-agg75.spgc-fit.STRAIN_META_BENCHMARK_GRID.flag

# For every species and both gene-reference versions (gene_stem), collect:
#   depth_meta[(species, gene_stem)]: max and total species depth over samples;
#   spgc_qc[(genome_id, gene_stem)]: StrainPGC strain-meta row of the genome's
#       genotype-matched strain;
#   benchmark[(genome_id, gene_stem, unit, tool, match)]: one accuracy row,
#       where match is "match" (the genotype-matched strain) or
#       "top"/"second"/"third" (strains ranked by F1).
benchmark = {}
depth_meta = {}
spgc_qc = {}
missing = []  # accuracy files that were not found (skipped)


for species in species_list:
    for gene_stem in ["gene99_v15-v22-agg75", "gene99_v20-v23-agg75"]:
        # NOTE(review): no existence check here (unlike the later species-depth
        # cell), so a missing species-depth file raises FileNotFoundError.
        depth_path = f"data/group/xjin/species/sp-{species}/r.proc.{gene_stem}.spgc_specgene-ref-filt-p95.species_depth.tsv"
        _depth = pd.read_table(
            depth_path, names=["sample_id", "depth"], index_col="sample_id"
        ).depth
        depth_meta[(species, gene_stem)] = pd.Series(
            dict(species_depth_max=_depth.max(), species_depth_sum=_depth.sum())
        )
        for genome_id, d in genome[lambda x: x.species_id == species].iterrows():
            # Strain chosen for this genome by genotype matching; raises
            # KeyError if the genome has no row in genome_to_spgc_strain.
            matched_strain = genome_to_spgc_strain.strain[genome_id]
            # NOTE(review): this path depends only on (species, gene_stem), yet
            # the file is re-read for every genome of the species.
            spgc_qc_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.{gene_stem}.spgc-fit.strain_meta-s95-d100-a0-pos100-std25.tsv"
            if os.path.exists(spgc_qc_path):
                qc = pd.read_table(spgc_qc_path, index_col="genome_id")
                if matched_strain in qc.index:
                    spgc_qc[(genome_id, gene_stem)] = qc.loc[matched_strain]

            # SPGC: StrainPGC variants, which have a genotype-matched strain.
            for unit in ["uhggtop", "eggnog", "cog"]:
                for tool in [
                    "spgc-fit",
                    # "spgc2-fit",
                    # "nnmatched-m50",
                    # "nnmatched-m10",
                    # "nnmatched-m1",
                    # "nnmatched-m0",
                    "spgc-depth200",
                ]:
                    accuracy_path = f"data/group/xjin/species/sp-{species}/r.proc.{gene_stem}.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
                    if not os.path.exists(accuracy_path):
                        missing.append(accuracy_path)
                        continue
                    # Rows are inferred strains, best F1 first.
                    accuracy = (
                        pd.read_table(accuracy_path, index_col="strain")
                        .assign(
                            species=species,
                            accuracy_path=accuracy_path,
                            strain=lambda x: x.index,
                        )
                        .sort_values("f1", ascending=False)
                    )

                    # "match" = the genotype-matched strain (if it was scored);
                    # "top"/"second" = best and second-best strains by F1.
                    if matched_strain in accuracy.index:
                        benchmark[
                            (genome_id, gene_stem, unit, tool, "match")
                        ] = accuracy.loc[matched_strain]
                    if accuracy.shape[0] >= 1:
                        benchmark[
                            (genome_id, gene_stem, unit, tool, "top")
                        ] = accuracy.iloc[0]
                    if accuracy.shape[0] >= 2:
                        benchmark[
                            (genome_id, gene_stem, unit, tool, "second")
                        ] = accuracy.iloc[1]
                # Reference tools (panphlan, spanda-*) have no genotype-matched
                # strain, so the top-F1 strain is also recorded as "match".
                for tool in [
                    "panphlan",
                    # "spanda-s2",
                    # "spanda-s3",
                    # "spanda-s4",
                    # "spanda-s5",
                    "spanda-s6",
                ]:
                    accuracy_path = f"data/group/xjin/species/sp-{species}/r.proc.{gene_stem}.{tool}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
                    if not os.path.exists(accuracy_path):
                        missing.append(accuracy_path)
                        continue
                    accuracy = (
                        pd.read_table(accuracy_path, index_col="strain")
                        .assign(
                            species=species,
                            accuracy_path=accuracy_path,
                            strain=lambda x: x.index,
                        )
                        .sort_values("f1", ascending=False)
                    )
                    if accuracy.shape[0] >= 1:
                        benchmark[
                            (genome_id, gene_stem, unit, tool, "top")
                        ] = accuracy.iloc[0]
                        benchmark[
                            (genome_id, gene_stem, unit, tool, "match")
                        ] = accuracy.iloc[0]
                    if accuracy.shape[0] >= 2:
                        benchmark[
                            (genome_id, gene_stem, unit, tool, "second")
                        ] = accuracy.iloc[1]
                    if accuracy.shape[0] >= 3:
                        benchmark[
                            (genome_id, gene_stem, unit, tool, "third")
                        ] = accuracy.iloc[2]

# Convert each dict to a DataFrame whose MultiIndex levels are named after the
# dict's key tuple.
benchmark = pd.DataFrame(benchmark.values(), index=benchmark.keys()).rename_axis(
    ["genome_id", "gene_stem", "unit", "tool", "match"]
)
depth_meta = pd.DataFrame(depth_meta.values(), index=depth_meta.keys()).rename_axis(
    ["species", "gene_stem"]
)
spgc_qc = pd.DataFrame(spgc_qc.values(), index=spgc_qc.keys()).rename_axis(
    ["genome_id", "gene_stem"]
)
# .assign(
#     qc_code=lambda x: x.apply(assign_qc_code, axis=1)
# )

In [None]:
# Per-species depth summaries (species x gene_stem tables).
d_max, d_sum = (
    depth_meta[_col].unstack() for _col in ("species_depth_max", "species_depth_sum")
)

# Old (v15-v22, x) vs. new (v20-v23, y) gene reference: max depth and total
# depth per species, with a y = x reference line, on symlog axes.
plt.scatter("gene99_v15-v22-agg75", "gene99_v20-v23-agg75", data=d_max)
plt.scatter("gene99_v15-v22-agg75", "gene99_v20-v23-agg75", data=d_sum)
plt.plot([1e-2, 1e5], [1e-2, 1e5])
plt.xscale("symlog", linthresh=1e-2)
plt.yscale("symlog", linthresh=1e-2)

In [None]:
# How many additional species had _no_ depth whatsoever?
# (Lists genomes whose species has zero total depth under the v20-v23 reference.)
genome.join(
    depth_meta.xs("gene99_v20-v23-agg75", level="gene_stem"),
    on="species_id",
).loc[lambda df: df["species_depth_sum"].eq(0)]

In [None]:
# And how many had any depth?
# genome_filt_list: genomes of known species with nonzero total depth under the
# v20-v23 reference; species_filt_list: the distinct species they cover.
genome_filt_list = list(
    genome.join(
        depth_meta.xs("gene99_v20-v23-agg75", level="gene_stem"),
        on="species_id",
    )
    .loc[lambda df: df["species_id"].ne("UNKNOWN") & df["species_depth_sum"].gt(0)]
    .index
)
species_filt_list = list(genome.loc[genome_filt_list, "species_id"].unique())

len(genome_filt_list), len(species_filt_list)

### Number of strains per species after filtering

In [None]:
# How many species have 1, 2, ... benchmark genomes (after filtering).
(
    genome.loc[genome_filt_list, "species_id"]
    .value_counts()
    .value_counts()
    .sort_index()
)

In [None]:
# How many actual SPGC fits are there?
# StrainPGC F1 per genome (eggnog units, matched strain, v20-v23), ascending.
benchmark.xs(
    ("gene99_v20-v23-agg75", "eggnog", "spgc-fit", "match"),
    level=("gene_stem", "unit", "tool", "match"),
)["f1"].sort_values()

In [None]:
# Which theoretically estimate-able genomes have no StrainPGC results?
set(genome_filt_list).difference(
    benchmark.xs(
        ("eggnog", "match", "spgc-fit", "gene99_v20-v23-agg75"),
        level=("unit", "match", "tool", "gene_stem"),
    ).index
)

In [None]:
# Which zero-depth species have StrainPanDA results?

set(
    genome.join(
        depth_meta.xs("gene99_v20-v23-agg75", level="gene_stem"), on="species_id"
    )[lambda x: x.species_depth_max.fillna(0) == 0].index
) & set(
    benchmark.xs(
        ("spanda-s6", "gene99_v20-v23-agg75", "eggnog", "match"),
        level=("tool", "gene_stem", "unit", "match"),
    ).index
)

In [None]:
# Which zero-depth species have StrainPanDA results?

set(
    genome.join(
        depth_meta.xs("gene99_v20-v23-agg75", level="gene_stem"), on="species_id"
    )[lambda x: x.species_depth_max.fillna(0) == 0].index
) & set(
    benchmark.xs(
        ("panphlan", "gene99_v20-v23-agg75", "eggnog", "match"),
        level=("tool", "gene_stem", "unit", "match"),
    ).index
)

In [None]:
# One row per genome of a known species: genome metadata, its matched StrainPGC
# strain, and that strain's eggnog-level benchmark scores (v20-v23 reference).
# Genomes without a result keep NaN scores (left join).
d = (
    genome.join(genome_to_spgc_strain)
    .loc[lambda df: df["species_id"] != "UNKNOWN"]
    .join(
        benchmark.xs(
            ("gene99_v20-v23-agg75", "spgc-fit", "match", "eggnog"),
            level=("gene_stem", "tool", "match", "unit"),
        ),
        rsuffix="_",
    )
)
d

# Statistics

## Species Depth

In [None]:
gene_stem = "gene99_v20-v23-agg75"

# Samples x species depth matrix (0 where a species has no entry for a sample).
# Species whose depth file does not exist are skipped and tallied.
species_depth = []
_missing_species = []

for species in tqdm(species_list):
    inpath = f"data/group/xjin/species/sp-{species}/r.proc.{gene_stem}.spgc_specgene-ref-filt-p95.species_depth.tsv"
    if os.path.exists(inpath):
        data = pd.read_table(inpath, names=["sample", "depth"]).assign(species=species)
        species_depth.append(data)
    else:
        _missing_species.append(species)

species_depth = (
    pd.concat(species_depth)
    .set_index(["sample", "species"])["depth"]
    .unstack(fill_value=0)
)

print(f"{len(_missing_species)} out of {len(species_list)} species are missing.")

In [None]:
# Samples x strains depth matrix: each species' per-sample depth (species_depth)
# is split across its strains in proportion to the strain fractions in that
# species' sfacts community file (comm.tsv).
strain_depth = []
missing_files = []
for species_id in species_depth.columns:
    path = f"data/group/xjin/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    try:
        # Sample x strain fraction table; assumes comm.tsv has exactly one value
        # column besides the (sample, strain) index — TODO confirm.
        d = (
            pd.read_table(path, index_col=["sample", "strain"])
            .squeeze()
            .unstack()
            # .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            # .rename({'CF_11': 'CF_15', 'CF_15': 'CF_11'})  # Sample swap.
        )
    except FileNotFoundError:
        # No community file: an empty table, so all of this species' depth ends
        # up in the -1 ("other") pseudo-strain below.
        missing_files.append(path)
        d = pd.DataFrame([])
    # Keep strains whose fraction summed over all samples exceeds 0.05.
    _keep_strains = idxwhere(d.sum() > 0.05)
    assert d.index.isin(species_depth.index).all()
    d = d.reindex(index=species_depth.index, columns=_keep_strains, fill_value=0)
    # The remaining fraction goes to pseudo-strain -1. It is negative when the
    # kept fractions sum to > 1, so clip at 0, then renormalize rows to sum to 1.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    d[d < 0] = 0
    d = d.divide(d.sum(1), axis=0)
    # Scale strain fractions by the species' depth in each sample.
    d = d.multiply(species_depth[species_id], axis=0)
    # Column labels become "<species_id>_<strain>".
    d = d.rename(columns=lambda s: f"{species_id}_{s}")
    strain_depth.append(d)
strain_depth = pd.concat(strain_depth, axis=1)
# Per-sample relative abundance of each strain across all species.
# NOTE(review): samples with zero total depth yield NaN rows here.
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)
len(species_depth.columns), len(missing_files)

In [None]:
# Strain depths in "xjin_"-prefixed samples only, keeping strains whose total
# depth over those samples exceeds 0.01; histogram of log10 total depth.
xjin_strain_depth = strain_depth.loc[strain_depth.index.str.startswith("xjin_")]
xjin_strain_depth = xjin_strain_depth[idxwhere(xjin_strain_depth.sum() > 0.01)]
plt.hist(np.log10(xjin_strain_depth.sum()), bins=20)

## Number of benchmark samples

In [None]:
species_depth.shape[0]

In [None]:
(species_depth.max().sort_values() > 0).sum()

### Figure 2A (Species in Phylogeny)

In [None]:
with open("fig/xjin_benchmark_species_list.txt", "w") as f:
    for species in idxwhere((species_depth.max().sort_values() > 0)):
        print(species, file=f)

In [None]:
genome.loc[lambda x: x.species_id.isin(idxwhere(species_depth.max().sort_values() > 0))]

In [None]:
# Strains with total xjin-sample depth > 1, and the per-sample summed depth of
# all remaining strains.
xjin_sotu_strain_list = idxwhere(xjin_strain_depth.sum() > 1)
other_strain_depth = xjin_strain_depth.drop(columns=xjin_sotu_strain_list).sum(
    axis=1
)

# Clustered heatmap (species x samples) of depth for species whose max depth is
# not <= 0; symmetric-log color scale, cosine distance, no tick labels.
d = species_depth.loc[:, ~(species_depth.max() <= 0)]
sns.clustermap(
    d.T,
    metric="cosine",
    norm=mpl.colors.SymLogNorm(linthresh=0.1),
    figsize=(5, 8),
    xticklabels=0,
    yticklabels=0,
)

In [None]:
xjin_species_depth = species_depth[lambda x: x.index.str.startswith("xjin_")].rename(
    columns=str
)
xjin_species_depth = xjin_species_depth[idxwhere((xjin_species_depth.sum() > 0.01))]
plt.hist(np.log10(xjin_species_depth.sum()), bins=20)

## Depth Distribution of Species in Benchmark Samples

In [None]:
xjin_species_depth

#### Figure 2B

In [None]:
# Figure 2B: per-species total vs. max depth over the xjin benchmark samples,
# colored by phylum, with marker size/shape by number of benchmark genomes.
d0 = xjin_species_depth.reindex(species_filt_list, axis=1, fill_value=0)
print(d0.max().quantile([0.25, 0.5, 0.75]))
_num_genomes_per_species = (
    genome_to_spgc_strain.assign(species=genome.species_id)
    .species.value_counts()
    .reindex(
        d0.columns, fill_value=0
    )  # TODO: Do we only want to consider these 85 species, or the larger list?
)
_num_genomes_order = range(1, _num_genomes_per_species.max() + 1)
_num_genomes_palette = lib.plot.construct_ordered_palette(
    _num_genomes_order, cm="plasma", other="white"
)
# (marker size, marker shape) for species with 1, 2, 3, 4 matched genomes.
_num_genomes_marker_specs = [(15, "o"), (30, "s"), (70, "p"), (120, "*")]
if len(_num_genomes_order) > len(_num_genomes_marker_specs):
    # zip() below would silently drop the extra counts and the plotting loop
    # would then fail with a bare KeyError; fail early with an explanation.
    raise ValueError(
        f"A species has {len(_num_genomes_order)} benchmark genomes, but markers "
        f"are only defined for up to {len(_num_genomes_marker_specs)}."
    )
_num_genomes_size_and_marker_palette = dict(
    zip(_num_genomes_order, _num_genomes_marker_specs)
)

d1 = pd.DataFrame(
    dict(
        _max=d0.max(),
        _sum=d0.sum(),
        num_species_genomes=_num_genomes_per_species,
        p__=d0.columns.map(species_taxonomy.p__),
    )
).assign(
    depth_in_other_samples=lambda x: x._sum - x._max,
    phylum_c=lambda x: x.p__.map(phylum_palette),
)

fig, ax = plt.subplots(figsize=(3, 4.5))

# Most common phyla are drawn first, so rarer phyla end up on top.
phylum_zorder = species_taxonomy.loc[species_list].p__.value_counts().index

for num_species_genomes in _num_genomes_order:
    markersize, markershape = _num_genomes_size_and_marker_palette[
        num_species_genomes
    ]
    for phylum in phylum_zorder:
        d2 = d1[
            lambda x: (x.num_species_genomes == num_species_genomes) & (x.p__ == phylum)
        ]
        ax.scatter(
            "_sum",
            "_max",
            s=markersize,
            facecolors="phylum_c",
            lw=0.5,
            data=d2,
            marker=markershape,
            edgecolor="black",
            label="__nolegend__",
            alpha=0.85,
        )
    # Empty grey scatter that serves as the legend entry for this genome count.
    ax.scatter(
        [],
        [],
        s=markersize,
        facecolors="grey",
        lw=0.5,
        marker=markershape,
        edgecolor="black",
        label=num_species_genomes,
    )
ax.plot([1e-2, 1e6], [1e-2, 1e6], lw=1, linestyle="--", color="k")
ax.set_yscale("log", subs=[])
ax.set_xscale("log", subs=[])
ax.set_ylim(1.2e-2, 1e4)
ax.set_xlim(1.2e-2, 1e6)
ax.set_aspect(1)
ax.set_xlabel("Total Depth")
ax.set_ylabel("Max Depth")

leg = ax.legend(loc="upper left", title="Num. Strains", labelspacing=0.1, frameon=False)
leg._legend_box.align = "left"

fig.savefig("fig/fig2b_species_depth.pdf", bbox_inches="tight")

In [None]:
# Benchmark genomes per phylum (summed over the plotted species).
d1.groupby("p__")["num_species_genomes"].sum()

## How many strains/species/phyla are analyzed in the benchmark?

In [None]:
# Full list

print("total_num_benchmark_genomes:", len(genome_list))
print("total_num_benchmark_species", len(species_list))
print("total_num_benchmark_phyla", len(species_taxonomy.loc[species_list].p__.unique()))

In [None]:
# Filtered list

len(genome_filt_list), len(species_filt_list), len(
    species_taxonomy.loc[species_filt_list].p__.unique()
)

print("filt_num_benchmark_genomes:", len(genome_filt_list))
print("filt_num_benchmark_species", len(species_filt_list))
print(
    "filt_num_benchmark_phyla",
    len(species_taxonomy.loc[species_filt_list].p__.unique()),
)

In [None]:
# Strains that were excluded with comments:
# genome ID, species ID and comment (tab-separated) for each benchmark genome
# not in genome_filt_list, each followed by a blank line.
for genome_id, species_id, comments in (
    genome0.drop(genome_filt_list)
    .reset_index()[["genome_id", "species_id", "comments"]]
    .itertuples(index=False, name=None)
):
    print(genome_id, species_id, comments, sep="\t", end="\n\n")

## StrainPGC Performance

In [None]:
tool = "spgc-fit"
unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"

# Per-genome StrainPGC precision/recall/F1 (missing results scored 0), plus the
# matched strain's species_gene_frac.
spgc_benchmark_genome_performance_and_metadata = (
    genome_to_spgc_strain.loc[genome_filt_list]
    .join(
        benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem")),
        rsuffix="_",
    )
    .loc[:, ["precision", "recall", "f1"]]
    .unstack("tool")
    .xs(tool, level="tool", axis="columns")
    .fillna(0)
    .join(genome_to_spgc_strain["species_gene_frac"])
)

# Median and interquartile range of StrainPGC performance
_quartiles = [0.25, 0.5, 0.75]
print(
    spgc_benchmark_genome_performance_and_metadata.loc[
        :, ["precision", "recall", "f1"]
    ].quantile(_quartiles)
)

## How does it compare to other tools?

In [None]:
# Absolute F1 change relative to other tools (>0 means improvement)

unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"

tool_by_genome_f1_matrix = (
    genome_to_spgc_strain.loc[genome_filt_list]
    .join(
        benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem")),
        rsuffix="_",
    )
    .f1.unstack("tool")
    .fillna(0)
)

# Median and interquartile range of comparison to StrainPGC performance
print(tool_by_genome_f1_matrix.quantile([0.25, 0.5, 0.75]))
print()
print(
    (
        -(
            tool_by_genome_f1_matrix.subtract(
                tool_by_genome_f1_matrix["spgc-fit"], axis=0
            )
        )
    ).quantile([0.25, 0.5, 0.75])
)

In [None]:
((tool_by_genome_f1_matrix["spgc-fit"] - tool_by_genome_f1_matrix.T).T > 0).mean()

In [None]:
# Ratio of false-discovery rates (1 - precision) relative to other tools (<1 means improvement)
# A missing (genome, tool) result is scored as precision = 0, i.e. FDR = 1,
# matching the F1 comparison where a missing result counts as F1 = 0.

unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"

tool_by_genome_fdr_matrix = (
    genome_to_spgc_strain.loc[genome_filt_list]
    .join(
        benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem")),
        rsuffix="_",
    )
    .precision.unstack("tool")
    .fillna(0)  # missing result -> precision 0 ...
    .rsub(1)  # ... -> FDR 1
)

print(tool_by_genome_fdr_matrix.quantile([0.25, 0.5, 0.75]))
print()
# Per genome: FDR(spgc-fit) / FDR(tool).
print(
    (
        1
        / tool_by_genome_fdr_matrix.divide(
            tool_by_genome_fdr_matrix["spgc-fit"], axis=0
        )
    ).quantile([0.25, 0.5, 0.75])
)

In [None]:
# Ratio of false-negative rates (1 - recall) relative to other tools (<1 means improvement)
# A missing (genome, tool) result is scored as recall = 0, i.e. FNR = 1,
# matching the F1 comparison where a missing result counts as F1 = 0.

unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"

tool_by_genome_fnr_matrix = (
    genome_to_spgc_strain.loc[genome_filt_list]
    .join(
        benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem")),
        rsuffix="_",
    )
    .recall.unstack("tool")
    .fillna(0)  # missing result -> recall 0 ...
    .rsub(1)  # ... -> FNR 1
)

print(tool_by_genome_fnr_matrix.quantile([0.25, 0.5, 0.75]))
print()
# Per genome: FNR(spgc-fit) / FNR(tool).
print(
    (
        1
        / tool_by_genome_fnr_matrix.divide(
            tool_by_genome_fnr_matrix["spgc-fit"], axis=0
        )
    ).quantile([0.25, 0.5, 0.75])
)

In [None]:
# Ratio of (1 - F1) relative to other tools (<1 means improvement)
# A missing (genome, tool) result is scored as F1 = 0, i.e. 1 - F1 = 1,
# consistent with the absolute-F1 comparison.

unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"

tool_by_genome_f1c_matrix = (
    genome_to_spgc_strain.loc[genome_filt_list]
    .join(
        benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem")),
        rsuffix="_",
    )
    .f1.unstack("tool")
    .fillna(0)  # missing result -> F1 0 ...
    .rsub(1)  # ... -> 1 - F1 = 1
)

print(tool_by_genome_f1c_matrix.quantile([0.25, 0.5, 0.75]))
print()
# Per genome: (1 - F1(spgc-fit)) / (1 - F1(tool)).
print(
    (
        1
        / tool_by_genome_f1c_matrix.divide(
            tool_by_genome_f1c_matrix["spgc-fit"], axis=0
        )
    ).quantile([0.25, 0.5, 0.75])
)

#### Figure 2C

In [None]:
# Figure 2C: per-genome StrainPGC score (y) vs. each reference tool's score (x)
# as 2D histograms (eggnog units, matched strain, v20-v23 reference).
unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"


qc_code_palette = {
    0: "grey",
    1: "lightgreen",
    2: "lightblue",
    3: "lightsalmon",
    -1: "grey",
}

cmap = "Greys"
# Log color scale for the number of genomes falling in each 2D-histogram bin.
norm = mpl.colors.LogNorm(vmin=0.2, vmax=50)

d0 = benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem"))

_tool_comparison_order = [
    "panphlan",
    "spanda-s6",
    # "spgc-depth200",
]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2 * len(_tool_comparison_order), 2 * len(_score_order)),
    sharex=True,
    sharey=True,
    gridspec_kw=dict(hspace=0.1, wspace=0.1),
    squeeze=False,
)
y_tool = "spgc-fit"
for score, ax_row in zip(_score_order, axs):
    # Genome x tool table of this score, restricted to and ordered by
    # genome_filt_list; a genome or tool without a result scores 0 (instead of
    # raising KeyError for genomes that have no results at all).
    d1 = (
        d0[score]
        .unstack()
        .reindex(genome_filt_list)
        .fillna(0)
        .join(genome_to_spgc_strain)
        .assign(c=lambda x: x.qc_code.astype(int).map(qc_code_palette))
    )
    print(
        score, "SPGC IQR:", (d1[y_tool]).quantile([0.25, 0.5, 0.75]).round(3).tolist()
    )
    for x_tool, ax in zip(_tool_comparison_order, ax_row):
        # Quantiles of the per-genome difference (StrainPGC - x_tool) and the
        # Wilcoxon signed-rank p-value for the paired scores.
        print(
            "compared to: {}, {}, {:.1g}".format(
                x_tool,
                (d1[y_tool] - d1[x_tool])
                .quantile([0.05, 0.25, 0.5, 0.75, 0.95])
                .round(3)
                .tolist(),
                sp.stats.wilcoxon(d1[x_tool], d1[y_tool]).pvalue,
            )
        )
        art = ax.hist2d(
            x_tool,
            y_tool,
            data=d1,
            cmap=cmap,
            norm=norm,
            bins=np.linspace(0, 1, num=21),
        )
        ax.plot([0, 1], [0, 1], lw=0.5, linestyle="--", color="k")
    print()

# Row labels (the loop variable no longer shadows the _score_order list).
for score, ax in zip(_score_order, axs[:, 0]):
    ax.set_ylabel(score)
    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim(-0.05, 1.05)

for x_tool, ax in zip(_tool_comparison_order, axs[-1, :]):
    ax.set_xlabel(x_tool)
    ax.set_xticks([0, 0.5, 1])
    ax.set_xlim(-0.05, 1.05)

fig.savefig("fig/fig2c_benchmark.pdf", bbox_inches="tight")

fig, ax = plt.subplots(figsize=(0.1, 1.25))
fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax)
ax.set_yscale("log", subs=[])
ax.set_yticks([1, 10, 50])

fig.savefig("fig/fig2c_benchmark_cbar.pdf", bbox_inches="tight")

In [None]:
# FIXME: Load data as in the main-text figure.
# Per-genome precision/recall/F1 vs. max species depth for each tool (one axis
# per score), with a 10-genome rolling-mean line per tool.

unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"

_tool_list = [
    "spgc-fit",
    "panphlan",
    "spanda-s6",
    # "spgc-depth200",
    # "nnmatched-m50"
]
_score_list = ["precision", "recall", "f1"]
xvar = "species_depth_max"

# _tool_palette = lib.plot.construct_ordered_palette(_tool_list)
_tool_palette = lib.plot.construct_ordered_palette_from_list(
    _tool_list, colors=["tab:blue", "tab:green", "tab:orange", "tab:purple"]
)


# One row per (genome, tool) for the filtered genomes; missing results score 0.
d0 = (
    benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem"))[
        _score_list
    ]
    .reindex(product(genome_filt_list, _tool_list))
    .fillna(0)
    .join(genome)
    # [lambda x: ~x.species_id.isin(multi_genome_species)]
    .join(depth_meta.xs("gene99_v20-v23-agg75", level="gene_stem"), on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)

fig, axs = plt.subplots(
    len(_score_list),
    figsize=(5, 3 * len(_score_list)),
    sharex=True,
    sharey=True,
)

for _tool in _tool_list:
    d1 = d0.xs(_tool, level="tool").loc[genome_filt_list]
    for _score, ax in zip(_score_list, axs.flatten()):
        # NOTE(review): `fit`, the lowess_y/spline_y columns and the `smoothed`
        # frame below are computed but never plotted in this cell.
        fit = smf.ols(f'{_score} ~ cr({xvar}, 5, constraints="center")', data=d1).fit()
        # Rolling means are computed in xvar order and aligned back to d1's index.
        d2 = d1.assign(
            lowess_y=lambda d: sm.nonparametric.lowess(
                d[_score], d[xvar], it=10, frac=1 / 3, return_sorted=False
            ),
            spline_y=lambda d: fit.predict(pd.DataFrame({xvar: d[xvar]})),
            xvar_rolling_average=lambda d: d.sort_values(xvar)
            .rolling(window=10)[xvar]
            .mean(),
            score_rolling_average=lambda d: d.sort_values(xvar)
            .rolling(window=10)[_score]
            .mean(),
        )
        smoothed = pd.DataFrame({xvar: np.logspace(-0.5, 4, num=100)}).assign(
            spline_y=lambda d: fit.predict(pd.DataFrame({xvar: d[xvar]})),
            lowess_y=lambda d: sm.nonparametric.lowess(
                d2[_score],
                d2[xvar],
                xvals=d[xvar],
                it=10,
                frac=1 / 3,
                return_sorted=False,
            ),
        )
        ax.scatter(
            xvar,
            _score,
            data=d2,
            label=_tool,
            s=8,
            alpha=0.5,
            facecolor="none",
            color=_tool_palette[_tool],
        )
        print(_tool, _score, sp.stats.spearmanr(d2[xvar], d2[_score]))
        ax.plot(
            "xvar_rolling_average",
            "score_rolling_average",
            data=d2.sort_values(xvar),
            label="__nolegend__",
            lw=1,
            linestyle="-",
            color=_tool_palette[_tool],
        )

for _score, ax in zip(_score_list, axs.flatten()):
    ax.set_ylabel(_score)

# Axes share x and y, so the scale/limits set on the last axis apply to all.
ax.set_xscale("symlog", linthresh=1e-1, linscale=0.1)

pad = 1e-2
# ax.set_yscale(
#     "function",
#     functions=(lambda x: -np.log(1 + pad - x), lambda y: (1 + pad - np.exp(-y))),
# )
# ax.set_yticks(np.unique([0.0, 0.5, 0.75, 0.9, 0.95, 0.99]))
ax.set_ylim(-0.05, 1.05)

# Legend on the third (bottom) axis.
axs[2].legend(loc="lower right")

print(len(d2))

In [None]:
# FIXME: Load data as in the main-text figure.
# Figure S1: per-genome precision/recall/F1 of each tool (rows) against two
# properties of the matched strain (columns), with a 10-genome rolling-mean
# line per tool and a dashed line at each property's threshold.

unit = "eggnog"
match = "match"
gene_stem = "gene99_v20-v23-agg75"

_tool_list = [
    "spgc-fit",
    "panphlan",
    "spanda-s6",
    # "spgc-depth200",
    # "nnmatched-m50"
]
_score_list = ["precision", "recall", "f1"]

# _tool_palette = lib.plot.construct_ordered_palette(_tool_list)
_tool_palette = lib.plot.construct_ordered_palette_from_list(
    _tool_list, colors=["tab:blue", "tab:green", "tab:orange", "tab:purple"]
)


# One row per (genome, tool) for the filtered genomes; missing results score 0.
d0 = (
    benchmark.xs((unit, match, gene_stem), level=("unit", "match", "gene_stem"))[
        _score_list
    ]
    .reindex(product(genome_filt_list, _tool_list))
    .fillna(0)
    .join(genome)
    # [lambda x: ~x.species_id.isin(multi_genome_species)]
    .join(depth_meta.xs("gene99_v20-v23-agg75", level="gene_stem"), on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)


# (x variable, threshold marked by a dashed vertical line) per column.
_xvar_and_thresh_list = [("max_strain_depth", 1), ("num_strain_sample", 5)]

fig, axs = plt.subplots(
    len(_score_list),
    len(_xvar_and_thresh_list),  # was the undefined name `_xvar_list`
    figsize=(5 * len(_xvar_and_thresh_list), 3 * len(_score_list)),
    sharex="col",
    sharey=True,
    squeeze=False,
)

for (xvar, xvar_thresh), axs_col in zip(_xvar_and_thresh_list, axs.T):
    for _tool in _tool_list:
        d1 = d0.xs(_tool, level="tool").loc[genome_filt_list]
        for _score, ax in zip(_score_list, axs_col):
            # NOTE(review): `fit`, lowess_y/spline_y and `smoothed` are computed
            # but not plotted.
            fit = smf.ols(f'{_score} ~ cr({xvar}, 5, constraints="center")', data=d1).fit()
            d2 = d1.assign(
                lowess_y=lambda d: sm.nonparametric.lowess(
                    d[_score], d[xvar], it=10, frac=1 / 3, return_sorted=False
                ),
                spline_y=lambda d: fit.predict(pd.DataFrame({xvar: d[xvar]})),
                xvar_rolling_average=lambda d: d.sort_values(xvar)
                .rolling(window=10)[xvar]
                .mean(),
                score_rolling_average=lambda d: d.sort_values(xvar)
                .rolling(window=10)[_score]
                .mean(),
            )
            smoothed = pd.DataFrame({xvar: np.logspace(-0.5, 4, num=100)}).assign(
                spline_y=lambda d: fit.predict(pd.DataFrame({xvar: d[xvar]})),
                lowess_y=lambda d: sm.nonparametric.lowess(
                    d2[_score],
                    d2[xvar],
                    xvals=d[xvar],
                    it=10,
                    frac=1 / 3,
                    return_sorted=False,
                ),
            )
            ax.scatter(
                xvar,
                _score,
                data=d2,
                label=_tool,
                s=8,
                alpha=0.5,
                facecolor="none",
                color=_tool_palette[_tool],
            )
            print(xvar, _tool, _score, sp.stats.spearmanr(d2[xvar], d2[_score]))
            ax.plot(
                "xvar_rolling_average",
                "score_rolling_average",
                data=d2.sort_values(xvar),
                label="__nolegend__",
                lw=1,
                linestyle="-",
                color=_tool_palette[_tool],
            )
    # Format each axis in this column once (previously repeated for every tool,
    # which stacked duplicate threshold lines).
    for ax in axs_col:
        ax.set_xscale("symlog", linthresh=1e-1, linscale=0.1)
        ax.set_ylim(-0.05, 1.05)
        ax.axvline(xvar_thresh, lw=1, linestyle="--", color="k")

for _score, ax in zip(_score_list, axs[:, 0]):
    ax.set_ylabel(_score)

for (xvar, xvar_thresh), ax in zip(_xvar_and_thresh_list, axs[-1, :]):
    ax.set_xlabel(xvar)

axs[-1, 0].legend(loc="lower right")

print(len(d2))

fig.savefig("fig/fig_s1.pdf")

In [None]:
# Rebuild the per-(genome, tool) score table: missing benchmark rows score 0,
# then attach genome metadata, species-level `depth_meta` rows and the
# genome's SPGC strain metadata.
_d0_scores = benchmark.xs(
    (unit, match, gene_stem), level=("unit", "match", "gene_stem")
)[_score_list]
_d0_species_depth = depth_meta.xs("gene99_v20-v23-agg75", level="gene_stem")
d0 = (
    _d0_scores.reindex(product(genome_filt_list, _tool_list))
    .fillna(0)
    .join(genome)
    .join(_d0_species_depth, on="species_id")
    .join(genome_to_spgc_strain, rsuffix="_")
)

# Strain-pure sample count vs. total strain depth, colored by F1 (log-log axes).
plt.scatter(
    "num_strain_sample",
    "sum_strain_depth",
    c="f1",
    data=d0,
    vmin=0.4,
    vmax=1.0,
)
plt.colorbar()
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Num. Strain-pure Samples")
# To export: fig.savefig("fig/num_samples_by_total_depth.pdf")
plt.ylabel("Total Strain Depth")

## Quality Control Stats

In [None]:
unit = "eggnog"
match = "match"
tool = "spgc-fit"
gene_stem = "gene99_v20-v23-agg75"

# spgc-fit benchmark rows for each filtered genome, joined to its SPGC strain
# metadata. Genomes without a benchmark row or without a strain row end up
# with NaN values here.
d = (
    benchmark.xs(
        (unit, match, tool, gene_stem), level=("unit", "match", "tool", "gene_stem")
    )
    .reindex(genome_filt_list)
    .join(genome_to_spgc_strain, rsuffix="_")
    .assign(log_sum_strain_depth=lambda x: np.log10(x.sum_strain_depth))
    .loc[genome_filt_list]
)

# NOTE: There is one species (Collinsiella?) where strain deconvolution failed,
# and which therefore has no SPGC strains.
print(d.shape)

_qc_xvar_list = [
    "log_selected_gene_depth_ratio_std",
    "species_gene_frac",
    "log_sum_strain_depth",
]
_qc_score_list = ["precision", "recall", "f1"]

# Columns are QC variables (shared x-axis), rows are scores (shared y-axis).
fig, axs = plt.subplots(
    len(_qc_score_list),
    len(_qc_xvar_list),
    squeeze=False,
    figsize=(10, 10),
    sharex="col",
    sharey="row",
)

# product() is x-major and axs.T.flatten() walks down each column, so each
# x-variable fills one column with one score per row.
for (x, y), ax in zip(product(_qc_xvar_list, _qc_score_list), axs.T.flatten()):
    # Drop rows missing either value; with NaNs present spearmanr just returns nan.
    _pair = d[[x, y]].dropna()
    print((x, y), sp.stats.spearmanr(_pair[x], _pair[y]))
    ax.scatter(x, y, data=d, s=15, alpha=0.5)
    ax.set_xlabel(x)
    ax.set_ylabel(y)

In [None]:
unit = "eggnog"
match = "match"
tool = "spgc-fit"
gene_stem = "gene99_v20-v23-agg75"

# Precision/recall/F1 of spgc-fit per filtered genome. Genomes with no
# benchmark row are scored 0 rather than NaN. The genome's SPGC strain
# metadata is then attached.
_benchmark_levels = ("unit", "match", "tool", "gene_stem")
_spgc_scores = benchmark.xs(
    (unit, match, tool, gene_stem), level=_benchmark_levels
)[["precision", "recall", "f1"]]
d = (
    _spgc_scores.reindex(genome_filt_list, fill_value=0)
    .join(genome_to_spgc_strain, rsuffix="_")
    .assign(log_sum_strain_depth=lambda x: np.log10(x.sum_strain_depth))
    .loc[genome_filt_list]
)

# Genomes whose `passes_filter` flag is False.
low_quality_strains = idxwhere(~d.passes_filter)

In [None]:
# F1 distributions for genomes in `low_quality_strains` vs. all other genomes,
# on a shared set of 50 bin edges spanning [0, 1].
bins = np.linspace(0, 1)

plt.hist(d.loc[low_quality_strains].f1, bins=bins, label="fails filter")
plt.hist(d.drop(low_quality_strains).f1, bins=bins, alpha=0.5, label="passes filter")
plt.xlabel("f1")
plt.ylabel("Num. Genomes")
plt.legend()

In [None]:
# F1 quartiles for genomes in `low_quality_strains`, then for the rest,
# followed by a Mann-Whitney U test between the two groups.
_f1_low = d.loc[low_quality_strains].f1
_f1_rest = d.drop(low_quality_strains).f1
for _f1_group in (_f1_low, _f1_rest):
    print(_f1_group.quantile([0.25, 0.5, 0.75]))
    print()
print(sp.stats.mannwhitneyu(_f1_low, _f1_rest))

In [None]:
# Precision quartiles for genomes in `low_quality_strains`, then for the rest,
# followed by a Mann-Whitney U test between the two groups.
_precision_col = "precision"
_precision_low = d.loc[low_quality_strains][_precision_col]
_precision_rest = d.drop(low_quality_strains)[_precision_col]
for _precision_group in (_precision_low, _precision_rest):
    print(_precision_group.quantile([0.25, 0.5, 0.75]))
    print()
print(sp.stats.mannwhitneyu(_precision_low, _precision_rest))

In [None]:
# Recall quartiles for genomes in `low_quality_strains`, then for the rest,
# followed by a Mann-Whitney U test between the two groups.
_recall_col = "recall"
_recall_low = d.loc[low_quality_strains][_recall_col]
_recall_rest = d.drop(low_quality_strains)[_recall_col]
for _recall_group in (_recall_low, _recall_rest):
    print(_recall_group.quantile([0.25, 0.5, 0.75]))
    print()
print(sp.stats.mannwhitneyu(_recall_low, _recall_rest))

## Relationship between thresholds and performance

In [None]:
# Collect per-genome reconstruction accuracy over the grid of spgc-fit
# (depth_thresh, corr_thresh) settings. Related build targets:
# data/group/xjin/r.proc.gene99_v15-v22-agg75.spgc_specgene-ref-t25-p95.SPECIES_DEPTH_BENCHMARK_GRID.flag
# data/group/xjin/r.proc.gene99_v15-v22-agg75.ACCURACY_BENCHMARK_GRID.flag
# data/group/xjin/r.proc.gene99_v15-v22-agg75.spgc-fit.STRAIN_META_BENCHMARK_GRID.flag
# NOTE(review): these targets name gene99_v15-v22-agg75 while the loop below
# reads gene99_v20-v23-agg75 files — confirm which is current.

thresh_benchmark = {}
missing = []

gene_stem = "gene99_v20-v23-agg75"
tool = "spgc-fit"
unit = "eggnog"

# Threshold codes as written in the file names (thousandths):
# depth 50..350 and corr 0..550, both in steps of 50.
_depth_thresh_codes = [str(t) for t in range(50, 351, 50)]
_corr_thresh_codes = [str(t) for t in range(0, 551, 50)]

for species in species_list:
    # Iterate genome IDs only. Unpacking iterrows() into `d` would overwrite
    # the notebook-global `d` used by the QC cells above.
    for genome_id in genome.index[genome.species_id == species]:
        matched_strain = genome_to_spgc_strain.strain[genome_id]
        for depth_thresh, corr_thresh in product(_depth_thresh_codes, _corr_thresh_codes):
            depth_thresh_numeric = int(depth_thresh) / 1000
            corr_thresh_numeric = int(corr_thresh) / 1000
            accuracy_path = f"data/group/xjin/species/sp-{species}/r.proc.gtpro.sfacts-fit.{gene_stem}.spgc_specgene-ref-filt-p95_ss-all_t-10_thresh-corr{corr_thresh}-depth{depth_thresh}.{genome_id}.{unit}-reconstruction_accuracy.tsv"
            if not os.path.exists(accuracy_path):
                missing.append(accuracy_path)
                continue
            # One row per reconstructed strain, best F1 first.
            accuracy = (
                pd.read_table(accuracy_path, index_col="strain")
                .assign(
                    species=species,
                    accuracy_path=accuracy_path,
                    strain=lambda x: x.index,
                )
                .sort_values("f1", ascending=False)
            )

            key_prefix = (
                genome_id,
                gene_stem,
                unit,
                tool,
                depth_thresh_numeric,
                corr_thresh_numeric,
            )
            # "match": the strain matched to this genome; "top": the best-F1 strain.
            if matched_strain in accuracy.index:
                thresh_benchmark[(*key_prefix, "match")] = accuracy.loc[matched_strain]
            if len(accuracy) >= 1:
                thresh_benchmark[(*key_prefix, "top")] = accuracy.iloc[0]

print(f"{len(thresh_benchmark)} benchmark rows; {len(missing)} accuracy files missing")

# Building the index with explicit names also works when nothing was collected
# (rename_axis on an empty frame's RangeIndex would raise).
thresh_benchmark = pd.DataFrame(
    list(thresh_benchmark.values()),
    index=pd.MultiIndex.from_tuples(
        list(thresh_benchmark.keys()),
        names=[
            "genome_id",
            "gene_stem",
            "unit",
            "tool",
            "depth_thresh",
            "corr_thresh",
            "match",
        ],
    ),
)

In [None]:
# Median (across genomes) of each score for every (depth_thresh, corr_thresh)
# pair, restricted to the genome-matched spgc-fit strain with eggnog units.
d = (
    thresh_benchmark.assign(
        # Weighted harmonic mean of precision and recall; precision gets twice
        # the weight of recall. (`weights=` requires SciPy >= 1.9.)
        weighted_score=lambda x: sp.stats.hmean(
            x[["precision", "recall"]].values, weights=[2, 1], axis=1
        )
    )
    .xs(
        ("gene99_v20-v23-agg75", "eggnog", "spgc-fit", "match"),
        level=["gene_stem", "unit", "tool", "match"],
    )
    .groupby(level=("depth_thresh", "corr_thresh"))[
        ["precision", "recall", "f1", "weighted_score"]
    ]
    .median()
)

# Scores at depth_thresh=0.2, corr_thresh=0.4.
print(d.loc[(0.2, 0.4)])


fig, axs = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey=True)
# "weighted_score" is also available in `d` if a fourth panel is wanted.
for score, ax in zip(["precision", "recall", "f1"], axs.flatten()):
    # Rows are depth_thresh, columns are corr_thresh; each cell shows its median.
    score_grid = d[score].unstack()
    sns.heatmap(
        score_grid,
        annot=score_grid,
        fmt=".2f",
        annot_kws={"fontsize": 7},
        ax=ax,
        cbar=False,
    )
    ax.set_title(score)
    # Put the smallest depth threshold at the bottom.
    ax.invert_yaxis()

fig.savefig("fig/fig_s3.pdf")

# Top 10 threshold combinations by median F1.
d.sort_values("f1", ascending=False).head(10)