## Preamble

### Project Template

In [None]:
# Load IPython's autoreload extension (notebook magic, not plain Python); the reload
# behaviour itself is controlled by the separate `%autoreload` magic.
%load_ext autoreload

In [None]:
import os as _os

# Run everything relative to the directory named by $PROJECT_ROOT and echo the
# resolved working directory as the cell output.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import statsmodels.api as sm
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
# Reference genome metadata, one row per genome_id (checked below).
genome = pd.read_table(
    "meta/genome.tsv",
    index_col="genome_id",
)
assert not genome.index.duplicated().any()
genome

In [None]:
# Species that have an assigned species_id and a genome file on record.
# `read_table` parses blank cells as NaN and `NaN != ""` is True, so the string
# comparisons alone let rows with an empty species_id/genome_path through;
# the explicit `notna()` checks drop them.
species_list = list(
    genome[
        lambda x: x.species_id.notna()
        & (x.species_id != "TODO")
        & x.genome_path.notna()
        & (x.genome_path != "")
    ].species_id.unique()
)
len(species_list)

In [None]:
# Sample (mgen) IDs belonging to the XJIN benchmark, in file order.
xjin_sample_list = (
    pd.read_table('meta/XJIN_BENCHMARK/mgen.tsv', index_col='mgen_id')
    .index
    .tolist()
)

In [None]:
# Scratch check of one species' depth table, restricted to the XJIN samples.
# Builds its own path: previously this read `depth_path`, which is only assigned
# inside the benchmark loop further down, so a fresh top-to-bottom run hit NameError.
_depth_species = species_list[0]
_depth_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{_depth_species}/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv"
_depth = pd.read_table(_depth_path, names=['sample_id', 'depth'], index_col='sample_id').depth.loc[xjin_sample_list]
_depth

In [None]:
# Disabled scratch line: `path` is not assigned anywhere earlier in this notebook,
# so evaluating it raised NameError on a fresh top-to-bottom run.
# path

In [None]:
unit = "uhggtop"

benchmark = []  # one table per (species, tool): best-f1 row per genome_id
missing = []  # accuracy-summary paths that were not found on disk
depth_meta = {}  # species -> Series(max_depth, sum_depth) over the XJIN samples
for species in species_list:
    depth_path = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv"
    _depth = pd.read_table(depth_path, names=['sample_id', 'depth'], index_col='sample_id').depth.loc[xjin_sample_list]
    depth_meta[species] = pd.Series(dict(max_depth=_depth.max(), sum_depth=_depth.sum()))
    for tool in ["spgc-fit", "spanda-s2", "panphlan"]:
        accuracy_path = f"data/group/XJIN_BENCHMARK/species/sp-{species}/r.proc.gene99_new-v22-agg75.{tool}.{unit}-xjin_strain_summary.tsv"
        # Only the read can raise FileNotFoundError, so keep it alone in the try.
        try:
            _accuracy = pd.read_table(accuracy_path)
        except FileNotFoundError:
            # Record the path that was actually tried. This previously appended
            # `path`, a name never defined in this notebook (NameError).
            missing.append(accuracy_path)
            continue
        benchmark.append(
            _accuracy.assign(
                species=species,
                tool=tool,
            )
            # Keep only the highest-f1 row for each genome_id.
            .sort_values("f1", ascending=False)
            .groupby(["genome_id"])
            .head(1)
        )
print(f"{len(missing)} accuracy summaries not found")
benchmark = pd.concat(benchmark)
depth_meta = pd.DataFrame(depth_meta).T

benchmark.sort_values(["species", "tool"])

In [None]:
# Species IDs that occur on more than one row of the genome table.
multi_genome_species = idxwhere(genome.species_id.value_counts().gt(1))
multi_genome_species

In [None]:
# Species IDs removed by hand from the comparison (reason not recorded here).
excluded_species = [
    "100878",
]

In [None]:
# Species used in the tool comparison: all of species_list except species with
# several reference genomes and the hand-excluded ones. Sorted (key=str, so a
# stray non-string ID cannot break the sort) because set iteration order of str
# varies between kernel sessions under hash randomization.
considered_species = sorted(
    set(species_list) - set(multi_genome_species) - set(excluded_species),
    key=str,
)

In [None]:
# Benchmark rows per (species, tool), restricted to the considered species.
(
    benchmark.loc[benchmark.species.isin(considered_species), ["species", "tool"]]
    .value_counts()
    .sort_values(ascending=False)
)
# benchmark[lambda x: x.species == '100196']

In [None]:
# Per-genome f1 with one column per tool (0 where a tool has no result for that
# genome), annotated with the genome's species and that species' depth summary.
d0 = (
    benchmark[lambda x: x.species.isin(considered_species)]
    .set_index(["genome_id", "tool"])
    .f1.unstack(fill_value=0)
    .assign(species=genome.species_id)
    .join(depth_meta, on='species')
)

# panphlan (x) vs spgc-fit (y) f1, coloured by the species' max depth (log scale).
_depth_scatter = plt.scatter("panphlan", "spgc-fit", data=d0, c='max_depth', alpha=0.4, norm=mpl.colors.LogNorm())
plt.plot([0, 1], [0, 1], lw=1, linestyle="--", color="k")  # y = x reference line
plt.xlabel("panphlan f1")
plt.ylabel("spgc-fit f1")
# Without a colorbar the depth colouring cannot be read off the plot.
plt.colorbar(_depth_scatter, label="max_depth")

In [None]:
# Per-genome scores with columns indexed by (score, tool); a missing tool result
# counts as 0.
d1 = (
    benchmark[lambda x: x.species.isin(considered_species)]
    .set_index(["genome_id", "tool"])[["f1", "precision", "recall"]]
    .unstack(fill_value=0)
    .rename_axis(columns=["score", "tool"])
    .fillna(0)
)

### Figure A ###
# Grid of 2D histograms: one row per score, one column per comparison tool; each
# panel plots the comparison tool (x) against spgc-fit (y) across genomes.
_tool_comparison_order = ["panphlan", "spanda-s2"]
_score_order = ["precision", "recall", "f1"]
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_comparison_order),
    figsize=(2.5 * len(_tool_comparison_order) + 1.5, 2.5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
y = "spgc-fit"
nbins = 15
left_bound = 0.0
# Bin edges; the first bin always starts at 0 so values below left_bound still land
# in the first bin.
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
_panel_counts = []  # per-panel count arrays (NaN where count < cmin)
_panel_artists = []  # per-panel QuadMesh artists, put on a shared scale below
for score, ax_row in zip(_score_order, axs):
    d2 = d1.xs(score, level="score", axis="columns")
    print(score, "SPGC IQR:", (d2[y]).quantile([0.25, 0.5, 0.75]).round(3).tolist())
    for x, ax in zip(_tool_comparison_order, ax_row):
        _counts, _, _, _artist = ax.hist2d(
            x,
            y,
            data=d2,
            bins=bins,
            cmin=1,
            cmap="magma_r",
        )
        _panel_counts.append(_counts)
        _panel_artists.append(_artist)
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle="--", color="k")
        ax.set_aspect(1)
        # Quartiles of (spgc-fit - other tool) and a paired Wilcoxon test.
        print(
            "compared to:",
            x,
            (d2[y] - d2[x]).quantile([0.25, 0.5, 0.75]).round(3).tolist(),
            sp.stats.wilcoxon(d2[x], d2[y]).pvalue,
        )

# Each hist2d autoscales its own colour range, so the single colorbar below would
# otherwise only describe the last panel. Put every panel on one shared count
# scale (vmin=1 matches cmin=1).
_count_vmax = np.nanmax(np.stack(_panel_counts))
for _artist in _panel_artists:
    _artist.set_clim(1, _count_vmax)
cbar_artist = _panel_artists[-1]

# Axes are shared, so limits set on the last panel apply to all of them.
ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)

for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(
    cbar_artist, cax=cbar_ax,
    label="count strains"
)

In [None]:
# Display `d0`. When cells run in order this is the per-genome f1-by-tool table
# joined with species depth from the scatter cell above; the next cell rebinds
# `d0` to a different table, so the output depends on execution order.
d0

In [None]:
# Every benchmark row annotated with its species' depth summary.
# NOTE(review): unlike the cells above this is not restricted to
# considered_species — confirm that is intended.
d0 = benchmark.join(depth_meta, on='species')

# One labelled scatter series per tool so the colours can be told apart.
for tool in ['spgc-fit', 'panphlan', 'spanda-s2']:
    d1 = d0[lambda x: x.tool == tool]
    plt.scatter('max_depth', 'f1', data=d1, label=tool)

plt.xscale('log')
plt.xlabel('max_depth')
plt.ylabel('f1')
plt.legend()