### Preamble

#### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Work relative to the project root so the relative data/ and meta/ paths below resolve.
_project_root = _os.environ['PROJECT_ROOT']
_os.chdir(_project_root)
_os.path.realpath(_os.curdir)

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels.api as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable
from datetime import datetime
import sys

In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

#### Set Style

In [None]:
# Render inline figures at a reduced DPI and use seaborn's larger "talk" font/line sizes.
plt.rcParams['figure.dpi'] = 50
sns.set_context('talk')

## Set Parameters / Load Data

In [None]:
# Fixed params: benchmark group name and the file-name stems used to build result paths below.
group = 'XJIN_BENCHMARK'
centroid = 75
gene_params = f"99-v22-agg{centroid}"
stemA = 'r.proc'
stemB = 'sfacts-fit'
# Previously used, kept for reference:
# depth_thresh="250"
# specgene_params='specgene-ref-t25-p90'

In [None]:
# Reference genomes that have a genome_path, and the species they belong to.
# species_id is read as str: downstream code builds paths from it and compares it against
# string IDs (e.g. '100273', 'TODO'); without this, an all-numeric column would parse as int
# and those comparisons would silently never match.
ref_strains = pd.read_table(
    'meta/genome.tsv',
    index_col='genome_id',
    dtype={'species_id': str},
)[lambda x: x.genome_path.notna()]
# "TODO" is excluded from the species list (presumably a placeholder species ID — confirm).
species_list = [x for x in ref_strains.species_id.unique() if x != "TODO"]

In [None]:
EMPTY_DATA = pd.Series(dict(
            strain=None,
            precision=np.nan,
            recall=np.nan,
            f1=np.nan,
            genome_id=None,
            correlation_thresh=np.nan,
            depth_thresh=np.nan,
            dominant_strain=None,
            total_num_reference_genomes=-1,
            total_num_xjin_strains=-1,
            num_species_free_samples=-1,
            num_strain_samples=0,
            strain_depth_sum=np.nan,
            strain_depth_max=np.nan,
            # strain_depth_std=np.nan,
        )).to_frame().T

In [None]:
# Per-species sequencing depth in each xjin sample, read from the SPGC species-depth tables,
# plus each species' maximum depth across those samples.
seed = 0
max_strain_samples_order = [1, 10, 20, 999]

species_depth = {}
missing_tally = 0
empty_tally = 0
# The species-depth file does not depend on max_strain_samples, so read it once per species.
for species in tqdm(species_list, ncols=50):
    path = f"data/group/xjin_hmp2/species/sp-{species}/{stemA}.gene99-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv"
    if os.path.exists(path):
        d = pd.read_table(path, names=['sample', 'depth'], index_col='sample').depth
        if d.empty:
            empty_tally += 1
    else:
        # Record an empty series rather than reusing the previous species' `d`.
        d = pd.Series(dtype=float, name='depth', index=pd.Index([], dtype=object, name='sample'))
        missing_tally += 1
    species_depth[species] = d

species_depth = pd.DataFrame(species_depth).fillna(0)[lambda x: x.index.str.startswith('xjin_')]
max_species_depth = species_depth.max()

print(f"For species depth, {missing_tally} files were missing and {empty_tally} files were empty.", file=sys.stderr)

In [None]:
# TODO: Make these loop variables
seed = 0
max_strain_samples_order = [1, 10, 20, 999]
unit_order = ['eggnog', 'uhgg', 'cog', 'top_eggnog', 'ko']
# "corr0-depth250" through "corr550-depth250" in steps of 50.
thresh_params_order = [f"corr{corr}-depth250" for corr in range(0, 600, 50)]


def _load_strain_summary(path):
    """Read one `*-xjin_strain_summary.tsv` table, falling back to EMPTY_DATA.

    Returns ``(table, mtime, was_missing, was_empty)``. ``table`` is EMPTY_DATA when the
    file does not exist or has no rows; ``mtime`` is NaN when the file does not exist.
    """
    if not os.path.exists(path):
        return EMPTY_DATA, np.nan, True, False
    mtime = os.path.getmtime(path)
    table = pd.read_table(path)
    if table.empty:
        return EMPTY_DATA, mtime, False, True
    return table, mtime, False, False


xjin_benchmarking = []

# SPGC: one table per (species, specgene, exponent, threshold, sample subset, unit).
missing_tally = empty_tally = 0
spgc_grid = list(product(
    species_list,
    ["ref-t25-p95"],
    [1, 3],
    thresh_params_order,
    max_strain_samples_order,
    unit_order,
))
for species, specgene_params, trnsfm_exponent, thresh_params, max_strain_samples, unit in tqdm(spgc_grid, ncols=50):
    # The exponent appears in the file name as an integer number of tenths (1 -> 10, 3 -> 30).
    trnsfm_exponent_str = int(round(trnsfm_exponent * 10, 0))
    ss_params = 'all' if max_strain_samples == 999 else f'deepest-n{max_strain_samples}'
    path = f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_specgene-{specgene_params}_ss-xjin-{ss_params}_t-{trnsfm_exponent_str}_thresh-{thresh_params}.{unit}-xjin_strain_summary.tsv"
    d, mtime, was_missing, was_empty = _load_strain_summary(path)
    missing_tally += was_missing
    empty_tally += was_empty
    xjin_benchmarking.append(d.assign(
        species=species,
        tool='spgc',
        unit=unit,
        specgene_params=specgene_params,
        seed=seed,
        max_strain_samples=max_strain_samples,
        thresh_params=thresh_params,
        trnsfm_exponent=trnsfm_exponent,
        path=path,
        run_datetime=pd.to_datetime(mtime, unit='s'),
    ))

print(f"For SPGC, {missing_tally} files were missing and {empty_tally} files were empty.")

# PanPhlAn and StrainPanDA: one table per (species, tool, unit).
missing_tally = empty_tally = 0
for species, tool_string, unit in tqdm(list(product(species_list, ['panphlan', 'spanda-s2'], unit_order)), ncols=50):
    path = f"data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.{tool_string}.{unit}-xjin_strain_summary.tsv"
    d, mtime, was_missing, was_empty = _load_strain_summary(path)
    missing_tally += was_missing
    empty_tally += was_empty
    xjin_benchmarking.append(d.assign(
        species=species,
        tool=tool_string.split('-')[0],  # 'spanda-s2' -> 'spanda'
        unit=unit,
        path=path,
        run_datetime=pd.to_datetime(mtime, unit='s'),
    ))
print(f"For panphlan and spanda, {missing_tally} files were missing and {empty_tally} files were empty.")

# Flag rows to exclude: strains with more than one reference genome, and non-dominant strains.
# (Previously also: (num_strain_samples != max_strain_samples) & (max_strain_samples > 0).)
xjin_benchmarking = pd.concat(xjin_benchmarking).assign(
    to_drop=lambda x: (x.total_num_reference_genomes > 1) | (x.strain != x.dominant_strain),
).reset_index(drop=True)

xjin_benchmarking

In [None]:
# Modification times of the kept all-sample results, in chronological order.
_kept_all_samples = xjin_benchmarking[lambda x: (~x.to_drop) & (x.max_strain_samples == 999.)]
plt.plot(_kept_all_samples.run_datetime.sort_values().values)

In [None]:
# Number of kept rows per parameter combination, with thresholds spread across columns.
_param_cols = ['tool', 'unit', 'specgene_params', 'max_strain_samples', 'thresh_params', 'trnsfm_exponent', 'to_drop']
(
    xjin_benchmarking
    .loc[~xjin_benchmarking.to_drop, _param_cols]
    .value_counts(dropna=False)
    .unstack('thresh_params')
)

In [None]:
# Distinct values of each column among kept SPGC rows whose result file existed (non-NaT mtime).
_kept_spgc = xjin_benchmarking[lambda x: ~x.to_drop & x.run_datetime.notna() & x.tool.isin(['spgc'])]
_kept_spgc.apply(lambda col: col.unique())

In [None]:
# First distinct result path among kept rows that have an F1 score.
_kept_scored = xjin_benchmarking[~xjin_benchmarking.to_drop].dropna(subset=['f1'])
_kept_scored.apply(lambda col: col.unique()).path[0]

In [None]:
# Sanity check: these species must not have any row with a correlation threshold.
for species in ['100273', '100878', '102386', 'TODO']:
    _species_rows = xjin_benchmarking.loc[xjin_benchmarking.species == species]
    assert _species_rows.correlation_thresh.isna().all()

# Distinct values of each column among rows lacking a correlation threshold.
_no_thresh = xjin_benchmarking[xjin_benchmarking.correlation_thresh.isna()]
_no_thresh.apply(lambda col: col.unique())

In [None]:
# F1 distribution (eggnog units, all samples, corr350 threshold), one layer per transform
# exponent present in the results.
bins = np.linspace(0, 1, num=21)

fig, ax = plt.subplots()
# Iterate over the exponents that were actually loaded (the loader reads only 1 and 3); a
# hard-coded list such as [0.5, 1, 2, 3] adds empty layers with misleading legend entries.
for trnsfm_exponent in sorted(xjin_benchmarking.trnsfm_exponent.dropna().unique()):
    d = (
        xjin_benchmarking
        [lambda x: True
         & ~x.to_drop
         & x.unit.isin(['eggnog'])
         & (x.specgene_params == "ref-t25-p95")
         & (x.max_strain_samples == 999)
         & (x.thresh_params == "corr350-depth250")
         & (x.trnsfm_exponent == trnsfm_exponent)
        ]
    )
    ax.hist(d.f1.dropna(), bins=bins, alpha=0.5, label=trnsfm_exponent)
ax.set_xlabel('f1')
ax.set_ylabel('count strains')
ax.legend(title='trnsfm_exponent');

In [None]:
bins = np.linspace(0.4, 1, num=500)

fig, ax = plt.subplots()

# For each (threshold, sample subset) setting: sorted F1 values against 1 minus their
# empirical cumulative fraction (eggnog units, exponent 1).
for thresh_params, max_strain_samples in product(thresh_params_order, max_strain_samples_order):
    d = xjin_benchmarking.loc[lambda x: (
        ~x.to_drop
        & x.unit.isin(['eggnog'])
        & x.specgene_params.eq("ref-t25-p95")
        & x.max_strain_samples.eq(max_strain_samples)
        & x.thresh_params.eq(thresh_params)
        & x.trnsfm_exponent.eq(1.0)
    )].sort_values('f1')
    n_rows = len(d)
    frac_remaining = 1 - np.arange(1, n_rows + 1) / n_rows
    ax.plot(d.f1, frac_remaining, label=(thresh_params, max_strain_samples))

ax.legend(bbox_to_anchor=(1, 1))
ax.set_xlabel('f1')
ax.set_ylabel('frac strains')

In [None]:
# Distinct values of each column among all kept SPGC rows.
d = xjin_benchmarking[(~xjin_benchmarking.to_drop) & xjin_benchmarking.tool.isin(['spgc'])]
d.apply(lambda col: col.unique())

In [None]:
# Precision / recall / F1 of SPGC (eggnog units, exponent 3) per sample-subset size, split by
# correlation threshold (restricted to thresh_params_order).
d = xjin_benchmarking.loc[lambda x: (
    ~x.to_drop
    & x.unit.eq('eggnog')
    & x.tool.eq('spgc')
    & x.specgene_params.eq("ref-t25-p95")
    & x.thresh_params.isin(thresh_params_order)
    & x.trnsfm_exponent.eq(3.0)
)]

# max_strain_samples is NaN for non-SPGC tools, so the column is float; the order must match.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, axs = plt.subplots(3, figsize=(25, 15))

for ax, metric in zip(axs, ["precision", "recall", "f1"]):
    sns.swarmplot(data=d, x='max_strain_samples', y=metric, hue='thresh_params', order=_max_strain_samples_order, hue_order=thresh_params_order, ax=ax, dodge=True)
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_ylim(-0.05, 1.05)

In [None]:
# Precision / recall / F1 of SPGC (eggnog units, exponent 3) per sample-subset size, split by
# correlation threshold, without pre-filtering thresholds.
d = xjin_benchmarking.loc[lambda x: (
    ~x.to_drop
    & x.unit.eq('eggnog')
    & x.tool.eq('spgc')
    & x.specgene_params.eq("ref-t25-p95")
    & x.trnsfm_exponent.eq(3.0)
)]

# max_strain_samples is NaN for non-SPGC tools, so the column is float; the order must match.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, axs = plt.subplots(3, figsize=(25, 15))

for ax, metric in zip(axs, ["precision", "recall", "f1"]):
    sns.swarmplot(data=d, x='max_strain_samples', y=metric, hue='thresh_params', order=_max_strain_samples_order, hue_order=thresh_params_order, ax=ax, dodge=True)
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_ylim(-0.05, 1.05)

In [None]:
# Per-strain F1 change from corr0 to corr200 (eggnog, exponent 1) within each sample-subset size.
_f1_by_thresh = (
    xjin_benchmarking
    .loc[lambda x: (
        ~x.to_drop
        & x.unit.eq('eggnog')
        & x.tool.eq('spgc')
        & x.specgene_params.eq("ref-t25-p95")
        & x.trnsfm_exponent.eq(1.0)
    )]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    ['f1']
    .unstack('thresh_params')
)
d0 = (_f1_by_thresh["corr200-depth250"] - _f1_by_thresh["corr0-depth250"]).rename('delta')

# max_strain_samples is NaN for non-SPGC tools, so the column is float; the order must match.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(data=d0.reset_index(), x='max_strain_samples', y='delta', order=_max_strain_samples_order, ax=ax)
ax.axhline(0, color='k', lw=1, linestyle='--')

d1 = d0.unstack("max_strain_samples")
for n_samples in max_strain_samples_order:
    deltas = d1[n_samples]
    print(sp.stats.wilcoxon(deltas.dropna()), deltas.mean(), deltas.median())

In [None]:
# Per-strain F1 at every (threshold, sample subset) setting minus the single-sample depth-only
# baseline ("corr0-depth250" with 1 sample), shown for the corr200 threshold (eggnog, exponent 1).
d0 = (
    xjin_benchmarking[lambda x: True
        & ~x.to_drop
        & x.unit.isin(['eggnog'])
        & x.tool.isin(['spgc'])
        & (x.specgene_params == "ref-t25-p95")
        & (x.trnsfm_exponent == 1.0)
    ]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    .f1
    .unstack(['thresh_params', 'max_strain_samples'])
)
d1 = (
    d0
    .apply(lambda y: y - d0["corr0-depth250", 1])
    .stack(['thresh_params', 'max_strain_samples'])
    .rename("delta")
    .xs("corr200-depth250", level="thresh_params")
)

# max_strain_samples is NaN for non-SPGC tools, so the column is float; the order must match.
# Defined here rather than relying on an earlier cell, so this cell runs on its own.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(data=d1.reset_index(), x='max_strain_samples', y='delta', order=_max_strain_samples_order, ax=ax)
plt.axhline(0, color='k', lw=1, linestyle='--')

d2 = d1.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    print(sp.stats.wilcoxon(d2[max_strain_samples].dropna()), d2[max_strain_samples].mean(), d2[max_strain_samples].median())

In [None]:
# Per-strain F1 at every setting minus the 10-sample depth-only baseline ("corr0-depth250", 10),
# shown for the corr350 threshold (eggnog, exponent 1).
d0 = (
    xjin_benchmarking
    .loc[lambda x: (
        ~x.to_drop
        & x.unit.eq('eggnog')
        & x.tool.eq('spgc')
        & x.specgene_params.eq("ref-t25-p95")
        & x.trnsfm_exponent.eq(1.0)
    )]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    ['f1']
    .unstack(['thresh_params', 'max_strain_samples'])
)
_baseline = d0["corr0-depth250", 10]
d1 = (
    d0
    .sub(_baseline, axis=0)
    .stack(['thresh_params', 'max_strain_samples'])
    .rename("delta")
    .xs("corr350-depth250", level="thresh_params")
)

# max_strain_samples is NaN for non-SPGC tools, so the column is float; the order must match.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(data=d1.reset_index(), x='max_strain_samples', y='delta', order=_max_strain_samples_order, ax=ax)
ax.axhline(0, color='k', lw=1, linestyle='--')

d2 = d1.unstack("max_strain_samples")
for n_samples in max_strain_samples_order:
    deltas = d2[n_samples]
    print(sp.stats.wilcoxon(deltas.dropna()), deltas.mean(), deltas.median())

In [None]:
unit = 'eggnog'

# All-sample SPGC rows for `unit` (exponent 1) at a single correlation threshold.
_spgc_rows = lambda thresh: xjin_benchmarking[lambda x: (
    ~x.to_drop
    & x.unit.isin([unit])
    & x.tool.isin(['spgc'])
    & (x.specgene_params == "ref-t25-p95")
    & (x.max_strain_samples == 999)
    & (x.thresh_params == thresh)
    & (x.trnsfm_exponent == 1.0)
)]

# Pair each strain's corr0 scores (unsuffixed) with its corr350 scores (suffix "_alt").
d = _spgc_rows("corr0-depth250").merge(
    _spgc_rows("corr350-depth250"),
    on=["genome_id"],
    suffixes=("", "_alt"),
)

fig, axs = plt.subplots(3, figsize=(6, 15))

# Points colored by the corr0 run's strain_depth_sum, deepest drawn last.
_d_by_depth = d.sort_values("strain_depth_sum")
for score, ax in zip(['precision', 'recall', 'f1'], axs.flatten()):
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    artist = ax.scatter(score, f'{score}_alt', c="strain_depth_sum", data=_d_by_depth, norm=mpl.colors.LogNorm())
    fig.colorbar(artist, cax=cax)
    ax.plot([0.01, 0.99], [0.01, 0.99])
    ax.set_title(score)
    ax.set_aspect(1)
fig.tight_layout()

In [None]:
# Per-strain precision at every setting minus the all-samples corr200 result (eggnog, exponent 3),
# shown for corr200 itself, so the 999-sample column is the baseline (all zeros).
d0 = (
    xjin_benchmarking[lambda x: True
        & ~x.to_drop
        & x.unit.isin(['eggnog'])
        & x.tool.isin(['spgc'])
        & (x.specgene_params == "ref-t25-p95")
        & (x.trnsfm_exponent == 3.0)
    ]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    .precision
    .unstack(['thresh_params', 'max_strain_samples'])
)
d1 = (
    d0
    .apply(lambda y: y - d0["corr200-depth250", 999])
    .stack(['thresh_params', 'max_strain_samples'])
    .rename("delta")
    .xs("corr200-depth250", level="thresh_params")
)

fig, ax = plt.subplots(figsize=(25, 5))

# max_strain_samples is NaN for non-SPGC tools, so the column is float; the order must match.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

sns.swarmplot(data=d1.reset_index(), x='max_strain_samples', y='delta', order=_max_strain_samples_order, ax=ax)
plt.axhline(0, color='k', lw=1, linestyle='--')

d2 = d1.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    deltas = d2[max_strain_samples]
    try:
        result = sp.stats.wilcoxon(deltas.dropna())
    except ValueError as err:
        # SciPy can raise ValueError here, e.g. when every difference is zero (as for the
        # baseline column itself); report the skip instead of hiding it.
        print(f"max_strain_samples={max_strain_samples}: wilcoxon skipped ({err})")
        continue
    print(result, deltas.mean(), deltas.median())

In [None]:
# Per-strain deltas from the previous cell, most negative first.
d1.reset_index().sort_values('delta')

In [None]:
# SPGC scores at the corr350 threshold (eggnog units), comparing transform exponents 1 and 3
# within each sample-subset size.
d = xjin_benchmarking.loc[lambda x: (
    ~x.to_drop
    & x.unit.eq('eggnog')
    & x.tool.eq('spgc')
    & x.specgene_params.eq("ref-t25-p95")
    & x.thresh_params.isin(["corr350-depth250"])
    & x.trnsfm_exponent.isin([1.0, 3.0])
)]

# max_strain_samples is NaN for non-SPGC tools, so the column is float; the order must match.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, axs = plt.subplots(3, figsize=(25, 15))

for ax, metric in zip(axs, ["precision", "recall", "f1"]):
    sns.swarmplot(data=d, x='max_strain_samples', y=metric, hue='trnsfm_exponent', order=_max_strain_samples_order, ax=ax, dodge=True)
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_ylim(-0.05, 1.05)

In [None]:
d0 = (
    xjin_benchmarking
    [lambda x: True
     & ~x.to_drop
     & x.unit.isin(['eggnog'])
     & x.tool.isin(['spgc'])
     & (x.specgene_params == "ref-t25-p95")
    ]
    .set_index(['genome_id', 'correlation_thresh', 'depth_thresh', 'specgene_params', 'max_strain_samples', 'trnsfm_exponent'])
    .xs(
        (0.25, 'ref-t25-p95'),
        level=('depth_thresh', 'specgene_params')
    )
)

# Columns become (metric, correlation_thresh, trnsfm_exponent, max_strain_samples).
d1 = d0.unstack(['correlation_thresh', 'trnsfm_exponent', 'max_strain_samples'])
print(d1.columns.unique().to_frame().apply(lambda x: x.unique()))

score = 'f1'
c = 'strain_depth_sum'

# (title, x-axis setting, y-axis setting); each setting is
# (correlation_thresh, trnsfm_exponent, max_strain_samples).
comparison_panels = [
    ('Improves on single-sample depth-only approach.', (0.0, 1.0, 1), (0.2, 3.0, 999)),
    ('Improves on BEST depth-only performance (10 samples).', (0.0, 1.0, 10), (0.2, 3.0, 999)),
    ('Similar performance to best untransformed approach at 20 samples', (0.35, 1.0, 20), (0.2, 3.0, 20)),
    ("Robust to number of samples.", (0.2, 3.0, 20), (0.2, 3.0, 999)),
    ("Robust to higher cutoff.", (0.2, 3.0, 999), (0.25, 3.0, 999)),
    ("Robust to lower cutoff.", (0.2, 3.0, 999), (0.15, 3.0, 999)),
]
# Every panel is colored by the same column.
color_values = d1[(c, 0.25, 1.0, 20)]

fig, axs = plt.subplots(2, 3, figsize=(25, 15))

for ax, (title, x_setting, y_setting) in zip(axs.flatten(), comparison_panels):
    ax.set_title(title)
    artist = ax.scatter(
        d1[(score, *x_setting)],
        d1[(score, *y_setting)],
        c=color_values,
        norm=mpl.colors.SymLogNorm(linthresh=1),
    )
    ax.plot([0, 1], [0, 1])
    # Pass ax explicitly so each colorbar sits beside its own panel; matplotlib < 3.6 otherwise
    # took space from the current axes (the last subplot) for every colorbar.
    fig.colorbar(artist, ax=ax, shrink=0.3)
    ax.set_aspect(1)

In [None]:
_is_spgc_eggnog = lambda x: (
    ~x.to_drop
    & x.unit.eq('eggnog')
    & x.tool.eq('spgc')
    & x.specgene_params.eq("ref-t25-p95")
)
d0 = (
    xjin_benchmarking[_is_spgc_eggnog]
    .set_index(['genome_id', 'correlation_thresh', 'depth_thresh', 'specgene_params', 'max_strain_samples', 'trnsfm_exponent'])
    .xs((0.25, 'ref-t25-p95'), level=('depth_thresh', 'specgene_params'))
)

# Columns become (metric, correlation_thresh, trnsfm_exponent, max_strain_samples).
d1 = d0.unstack(['correlation_thresh', 'trnsfm_exponent', 'max_strain_samples'])

score = 'f1'
c = 'strain_depth_sum'
# Depth-only (corr 0.0, exponent 1) vs. SPGC (corr 0.2, exponent 3), both with all samples.
x = d1[(score, 0.0, 1.0, 999)]
y = d1[(score, 0.2, 3.0, 999)]
z = d1[(c, 0.2, 3.0, 999)]

print(sp.stats.wilcoxon(x, y))
print((y - x).quantile([0.1, 0.25, 0.75, 0.9]))

In [None]:
# Per-strain F1 gain (y - x from the previous cell) against strain depth, on a log depth axis.
_f1_gain = y - x
plt.scatter(z, _f1_gain)
plt.xscale('log')
plt.axhline(0, lw=1, linestyle='--', color='k')
plt.ylim(-1.0, 1.0)

In [None]:
# Mean and median F1 gain, and the fraction of strains where the gain is strictly positive.
_gain = y - x
_gain.mean(), _gain.median(), (_gain > 0).mean()

In [None]:
score_order = ['precision', 'recall', 'f1']

fig, axs = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)

# The grid has two rows, so only the first two units of unit_order are drawn.
for unit, ax_row in zip(unit_order[:2], axs):
    d0 = (
        xjin_benchmarking
        .loc[lambda x: (
            ~x.to_drop
            & x.unit.isin([unit])
            & x.tool.isin(['spgc'])
            & x.specgene_params.eq("ref-t25-p95")
        )]
        .set_index(['depth_thresh', 'specgene_params', 'genome_id', 'correlation_thresh', 'max_strain_samples', 'trnsfm_exponent'])
        .xs((0.25, 'ref-t25-p95'), level=('depth_thresh', 'specgene_params'))
    )
    # Columns become (metric, correlation_thresh, trnsfm_exponent, max_strain_samples).
    d1 = d0.unstack(['correlation_thresh', 'trnsfm_exponent', 'max_strain_samples'])

    c = 'strain_depth_sum'
    for score, ax in zip(score_order, ax_row):
        # Depth-only (corr 0.0, exponent 1) on x vs. SPGC (corr 0.1, exponent 3) on y.
        x = d1[(score, 0.0, 1.0, 999.)]
        y = d1[(score, 0.1, 3.0, 999.)]
        z = d1[(c, 0.1, 3.0, 999.)]
        artist = ax.scatter(x, y, c=z, norm=mpl.colors.SymLogNorm(linthresh=1, vmin=0, vmax=1e5))
        ax.plot([0, 1], [0, 1])
        ax.set_aspect(1)
    ax_row[0].set_ylabel(unit)

for score, ax_col in zip(score_order, axs.T):
    ax_col[-1].set_xlabel(score)

In [None]:
unit = 'uhgg'

# Kept rows for this unit; each selector below picks one tool/setting out of it.
d0 = xjin_benchmarking.loc[lambda x: ~x.to_drop & x.unit.isin([unit])]


def _make_spgc_selector(correlation_thresh, trnsfm_exponent):
    """Row mask for all-sample SPGC results (ref-t25-p95, depth 0.25) at the given settings."""
    return lambda x: (
        (x.tool == 'spgc')
        & (x.specgene_params == 'ref-t25-p95')
        & (x.depth_thresh == 0.25)
        & (x.correlation_thresh == correlation_thresh)
        & (x.max_strain_samples == 999)
        & (x.trnsfm_exponent == trnsfm_exponent)
    )


selector_panphlan = lambda x: x.tool == 'panphlan'
selector_spanda = lambda x: x.tool == 'spanda'
selector_depth = _make_spgc_selector(0, 1)
selector_spgc = _make_spgc_selector(0.1, 3.0)
selector_spgc_alt = _make_spgc_selector(0.5, 1.0)


# TODO: Plot SPGC against each of the three competitors.
# Show comparison to all three: precision/recall/f1

_tool_comparison_order = ['panphlan', 'spanda', 'depth_only']
_score_order = ['precision', 'recall', 'f1']
# Column name in the per-score table -> row selector.
_column_selectors = {
    'panphlan': selector_panphlan,
    'spanda': selector_spanda,
    'depth_only': selector_depth,
    'spgc': selector_spgc,
    'spgc_alt': selector_spgc_alt,
}

n_rows, n_cols = len(_score_order), len(_tool_comparison_order)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(2.5 * n_cols + 2.5, 2.5 * n_rows), sharex=True, sharey=True)
y = 'spgc'
nbins = 15
left_bound = 0.0
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    # One row per strain; a tool with no result for a strain counts as a score of 0.
    d1 = pd.DataFrame({
        **{name: d0[select].set_index('genome_id')[score] for name, select in _column_selectors.items()},
        'strain_depth_max': d0[selector_spgc].set_index('genome_id')['strain_depth_max'],
    }).fillna(0)
    print('spgc', len(d1[y]))
    for x, ax in zip(_tool_comparison_order, ax_row):
        *_, cbar_artist = ax.hist2d(x, y, data=d1, bins=bins, cmin=1, norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=64), cmap='magma_r')
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle='--', color='k')
        ax.set_aspect(1)
        print(score, x, y, (d1[y] - d1[x]).mean(), sp.stats.wilcoxon(d1[x], d1[y]).pvalue, len(d1[x]))

ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)
for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

# Shared colorbar in its own axes to the right of the grid.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(cbar_artist, cax=cbar_ax, ticks=[0, 1, 2, 4, 8, 16, 32, 64], label="count strains")

In [None]:
unit = 'uhgg'

d0 = xjin_benchmarking.loc[lambda x: ~x.to_drop & x.unit.isin([unit])]

selector_panphlan = lambda x: x.tool.eq('panphlan')
selector_spanda = lambda x: x.tool.eq('spanda')
selector_depth = lambda x: (
    x.tool.eq('spgc')
    & x.specgene_params.eq('ref-t25-p95')
    & x.depth_thresh.eq(0.25)
    & x.correlation_thresh.eq(0)
    & x.max_strain_samples.eq(999)
    & x.trnsfm_exponent.eq(1)
)
selector_spgc = lambda x: (
    x.tool.eq('spgc')
    & x.specgene_params.eq('ref-t25-p95')
    & x.depth_thresh.eq(0.25)
    & x.correlation_thresh.eq(0.1)
    & x.max_strain_samples.eq(999)
    & x.trnsfm_exponent.eq(3.0)
)
selector_spgc_alt = lambda x: (
    x.tool.eq('spgc')
    & x.specgene_params.eq('ref-t25-p95')
    & x.depth_thresh.eq(0.25)
    & x.correlation_thresh.eq(0.5)
    & x.max_strain_samples.eq(999)
    & x.trnsfm_exponent.eq(1.0)
)


# TODO: Plot SPGC against each of the three competitors.
# Show comparison to all three: precision/recall/f1

_tool_comparison_order = ['panphlan', 'spanda', 'depth_only']
_score_order = ['precision', 'recall', 'f1']

n_rows, n_cols = len(_score_order), len(_tool_comparison_order)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(2.5 * n_cols + 2.5, 2.5 * n_rows), sharex=True, sharey=True)
y = 'spgc'
nbins = 15
left_bound = 0.0
bins = [0] + list(np.linspace(left_bound, 1, num=nbins + 1)[1:])
for score, ax_row in zip(_score_order, axs):
    # One row per strain; a tool with no result for a strain counts as a score of 0.
    d1 = pd.DataFrame({
        'panphlan': d0[selector_panphlan].set_index('genome_id')[score],
        'spanda': d0[selector_spanda].set_index('genome_id')[score],
        'depth_only': d0[selector_depth].set_index('genome_id')[score],
        'spgc': d0[selector_spgc].set_index('genome_id')[score],
        'spgc_alt': d0[selector_spgc_alt].set_index('genome_id')[score],
        'strain_depth_max': d0[selector_spgc].set_index('genome_id')['strain_depth_max'],
    }).fillna(0)
    for x, ax in zip(_tool_comparison_order, ax_row):
        # Histogram an empty selection: draws the axes/colorbar layout without any data.
        *_, cbar_artist = ax.hist2d(x, y, data=d1.loc[[]], bins=bins, cmin=1, norm=mpl.colors.PowerNorm(1/2, vmin=0, vmax=64), cmap='magma_r')
        ax.plot([left_bound, 1], [left_bound, 1], lw=1, linestyle='--', color='k')
        ax.set_aspect(1)
        print(score, x, y, (d1[y] - d1[x]).mean(), sp.stats.wilcoxon(d1[x], d1[y]).pvalue)

ax.set_xlim(left_bound, 1)
ax.set_ylim(left_bound, 1)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)
for tool_comparison, ax_col in zip(_tool_comparison_order, axs.T):
    ax_col[0].set_title(tool_comparison)

# Shared colorbar in its own axes to the right of the grid.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(cbar_artist, cax=cbar_ax, ticks=[0, 1, 2, 4, 8, 16, 32, 64], label="count strains")

In [None]:
unit = 'uhgg'

d0 = (
    xjin_benchmarking
    [lambda x: ~x.to_drop & x.unit.isin([unit])]
)

selector_panphlan = lambda x: (x.tool == 'panphlan')
selector_spanda = lambda x: (x.tool == 'spanda')
selector_depth = lambda x: (
    (x.tool == 'spgc') &
    (x.specgene_params == 'ref-t25-p95') &
    (x.depth_thresh == 0.25) &
    (x.correlation_thresh == 0) &
    (x.max_strain_samples == 999) &
    (x.trnsfm_exponent == 1)
)
selector_spgc = lambda x: (
    (x.tool == 'spgc') &
    (x.specgene_params == 'ref-t25-p95') &
    (x.depth_thresh == 0.25) &
    (x.correlation_thresh == 0.1) &
    (x.max_strain_samples == 999) &
    (x.trnsfm_exponent == 3.0)
)
selector_spgc_alt = lambda x: (
    (x.tool == 'spgc') &
    (x.specgene_params == 'ref-t25-p95') &
    (x.depth_thresh == 0.25) &
    (x.correlation_thresh == 0.5) &
    (x.max_strain_samples == 999) &
    (x.trnsfm_exponent == 1.0)
)


# TODO: Plot SPGC against each of the three competitors.
# Show comparison to all three: precision/recall/f1

# Score vs. species depth for each tool, with a rolling-mean trend line per tool.
_tool_order = ['panphlan', 'spanda', 'depth_only', 'spgc']
tool_palette = lib.plot.construct_ordered_palette(_tool_order, cm='rainbow')
_score_order = ['precision', 'recall', 'f1']

# One 7-wide column per tool plus a final column overlaying all trend lines. Parenthesized so
# that extra column gets its own share of the width: `7 * len(_tool_order) + 1` added only 1
# to the total because `*` binds tighter than `+`.
fig, axs = plt.subplots(
    len(_score_order),
    len(_tool_order) + 1,
    figsize=(7 * (len(_tool_order) + 1), 5 * len(_score_order)),
    sharex=True,
    sharey=True,
)
for score, ax_row in zip(_score_order, axs):
    d1 = pd.DataFrame(dict(
        panphlan=d0[selector_panphlan].set_index('genome_id')[score],
        spanda=d0[selector_spanda].set_index('genome_id')[score],
        depth_only=d0[selector_depth].set_index('genome_id')[score],
        spgc=d0[selector_spgc].set_index('genome_id')[score],
        spgc_alt=d0[selector_spgc_alt].set_index('genome_id')[score],
        # Despite the column name, this is each strain's species' maximum depth across xjin
        # samples (max_species_depth), not a per-strain value.
        strain_depth_max=d0[selector_spgc].set_index('genome_id').species.map(max_species_depth),
    )).fillna(0)
    trend_ax = ax_row[-1]
    for tool, ax in zip(_tool_order, ax_row):
        ax.scatter('strain_depth_max', tool, data=d1, s=20)
        # Centered rolling mean over 20 neighboring strains, ordered by depth.
        smoothed = d1[['strain_depth_max', tool]].sort_values('strain_depth_max').rolling(window=20, min_periods=1, center=True).mean()
        # Trend drawn only above depth 0.1 (same value as the symlog linthresh below).
        ax.plot('strain_depth_max', tool, data=smoothed[lambda x: x.strain_depth_max > 0.1], color=tool_palette[tool], lw=4, zorder=0)
        trend_ax.plot('strain_depth_max', tool, data=smoothed[lambda x: x.strain_depth_max > 0.1], color=tool_palette[tool], lw=3, zorder=0, alpha=0.9, label=tool)

for ax_row in axs:
    for ax in ax_row:
        ax.axvline(1.0, lw=1, linestyle='--', color='k')
        ax.axhline(0.9, lw=1, linestyle=':', color='k')

# Axes share x (sharex=True), so the scale and limits set here apply to every panel.
ax.set_xscale('symlog', linthresh=1e-1)
ax.set_xlim(-1e-2)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)
ax_row[0].set_xlabel('strain_depth_max')

for tool, ax_col in zip(_tool_order, axs.T):
    ax_col[0].set_title(tool)

In [None]:
unit = 'eggnog'

d0 = (
    xjin_benchmarking
    [lambda x: ~x.to_drop & x.unit.isin([unit])]
)

tool = 'spgc'
# SPGC at fixed specgene/depth/exponent settings; correlation_thresh is left
# free and swept across the grid columns below.
selector = lambda x: (
    (x.tool == tool) &
    (x.specgene_params == 'ref-t25-p95') &
    (x.depth_thresh == 0.25) &
    (x.max_strain_samples == 999) &
    (x.trnsfm_exponent == 3.0) &
    True
)

# TODO: Plot SPGC against each of the three competitors.
# Show comparison to all three: precision/recall/f1

correlation_thresh_order = xjin_benchmarking.correlation_thresh.dropna().sort_values().unique()  # [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45]
correlation_thresh_palette = lib.plot.construct_ordered_palette(correlation_thresh_order, cm='rainbow')
_score_order = ['precision', 'recall', 'f1']
genome_id_list = d0.genome_id.unique()
_trend_window = 20       # genomes per centered rolling-mean window
_trend_min_depth = 0.1   # trend lines are only drawn above this depth (the symlog linthresh)

# Depth depends on the genome (via its species), not on correlation_thresh, so
# look it up once over every genome. Genomes with no result at a threshold then
# keep their true depth (with score 0) instead of collapsing onto depth 0.
# NOTE(review): assumes each genome_id maps to a single species.
_genome_depth = (
    d0.groupby('genome_id').species.first()
    .map(max_species_depth)
    .reindex(genome_id_list)
    .fillna(0)
)
_d0_selected = d0[selector]  # loop-invariant; filter once

# One panel per threshold plus a final column overlaying every threshold's trend.
fig, axs = plt.subplots(
    len(_score_order), len(correlation_thresh_order) + 1,
    figsize=(7 * (len(correlation_thresh_order) + 1), 5 * len(_score_order)),
    sharex=True, sharey=True,
)
for score, ax_row in zip(_score_order, axs):
    trend_ax = ax_row[-1]
    for correlation_thresh, ax in zip(correlation_thresh_order, ax_row):
        d1 = _d0_selected[lambda x: x.correlation_thresh == correlation_thresh].set_index('genome_id')
        # A genome with no result at this threshold scores 0.
        d2 = pd.DataFrame(dict(
            score=d1[score].reindex(genome_id_list).fillna(0),
            strain_depth_max=_genome_depth,
        ))
        ax.scatter('strain_depth_max', 'score', data=d2, s=20)
        smoothed = (
            d2[['strain_depth_max', 'score']]
            .sort_values('strain_depth_max')
            .rolling(window=_trend_window, min_periods=1, center=True)
            .mean()
        )
        _trend = smoothed[lambda x: x.strain_depth_max > _trend_min_depth]
        _color = correlation_thresh_palette[correlation_thresh]
        ax.plot('strain_depth_max', 'score', data=_trend, color=_color, lw=4, zorder=0)
        trend_ax.plot('strain_depth_max', 'score', data=_trend, color=_color, lw=4, zorder=0, alpha=0.9, label=correlation_thresh)

# Reference lines: depth 1.0 and a score of 0.9.
for ax in axs.flat:
    ax.axvline(1.0, lw=1, linestyle='--', color='k')
    ax.axhline(0.9, lw=1, linestyle=':', color='k')

# Axes are shared, so setting the scale/limits on one panel applies to all.
ax.set_xscale('symlog', linthresh=1e-1)
ax.set_xlim(-1e-2)

for score, ax_row in zip(_score_order, axs):
    ax_row[0].set_ylabel(score)
for ax in axs[-1]:
    ax.set_xlabel('strain_depth_max')

for correlation_thresh, ax_col in zip(correlation_thresh_order, axs.T):
    ax_col[0].set_title(correlation_thresh)
axs[0, -1].set_title('rolling-mean trends')
axs[0, -1].legend(title='correlation_thresh', fontsize='small')

In [None]:
unit = 'eggnog'
score = 'f1'
d0 = (
    xjin_benchmarking
    [lambda x: ~x.to_drop & x.unit.isin([unit])]
)

tool = 'spgc'
# SPGC at fixed specgene/depth settings; correlation_thresh (columns) and
# trnsfm_exponent (rows) are swept below.
selector = lambda x: (
    (x.tool == tool) &
    (x.specgene_params == 'ref-t25-p95') &
    (x.depth_thresh == 0.25) &
    (x.max_strain_samples == 999) &
    True
)

# TODO: Plot SPGC against each of the three competitors.
# Show comparison to all three: precision/recall/f1

correlation_thresh_order = xjin_benchmarking.correlation_thresh.dropna().sort_values().unique()  # [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45]
correlation_thresh_palette = lib.plot.construct_ordered_palette(correlation_thresh_order, cm='rainbow')
_trnsfm_exponent_order = [1.0, 3.0]
# Defined here rather than inherited from an earlier cell, so this cell runs on its own.
genome_id_list = d0.genome_id.unique()
_trend_window = 20       # genomes per centered rolling-mean window
_trend_min_depth = 0.1   # trend lines are only drawn above this depth (the symlog linthresh)

# Depth depends on the genome (via its species), not on the parameters swept
# here, so look it up once over every genome. Genomes with no result for a
# parameter pair then keep their true depth (with score 0) instead of collapsing
# onto depth 0. NOTE(review): assumes each genome_id maps to a single species.
_genome_depth = (
    d0.groupby('genome_id').species.first()
    .map(max_species_depth)
    .reindex(genome_id_list)
    .fillna(0)
)
_d0_selected = d0[selector]  # loop-invariant; filter once

# One panel per threshold plus a final column overlaying every threshold's trend.
fig, axs = plt.subplots(
    len(_trnsfm_exponent_order), len(correlation_thresh_order) + 1,
    figsize=(7 * (len(correlation_thresh_order) + 1), 5 * len(_trnsfm_exponent_order)),
    sharex=True, sharey=True,
)
for trnsfm_exponent, ax_row in zip(_trnsfm_exponent_order, axs):
    trend_ax = ax_row[-1]
    for correlation_thresh, ax in zip(correlation_thresh_order, ax_row):
        d1 = _d0_selected[
            lambda x: (x.correlation_thresh == correlation_thresh) & (x.trnsfm_exponent == trnsfm_exponent)
        ].set_index('genome_id')
        # A genome with no result for this parameter pair scores 0.
        d2 = pd.DataFrame(dict(
            score=d1[score].reindex(genome_id_list).fillna(0),
            strain_depth_max=_genome_depth,
        ))
        ax.scatter('strain_depth_max', 'score', data=d2, s=20)
        smoothed = (
            d2[['strain_depth_max', 'score']]
            .sort_values('strain_depth_max')
            .rolling(window=_trend_window, min_periods=1, center=True)
            .mean()
        )
        _trend = smoothed[lambda x: x.strain_depth_max > _trend_min_depth]
        _color = correlation_thresh_palette[correlation_thresh]
        ax.plot('strain_depth_max', 'score', data=_trend, color=_color, lw=4, zorder=0)
        trend_ax.plot('strain_depth_max', 'score', data=_trend, color=_color, lw=4, zorder=0, alpha=0.9, label=correlation_thresh)

# Reference lines: depth 1.0 and a score of 0.9.
for ax in axs.flat:
    ax.axvline(1.0, lw=1, linestyle=':', color='k')
    ax.axhline(0.9, lw=1, linestyle=':', color='k')

# Axes are shared, so setting the scale/limits on one panel applies to all.
ax.set_xscale('symlog', linthresh=1e-1)
ax.set_xlim(-1e-2)

for trnsfm_exponent, ax_row in zip(_trnsfm_exponent_order, axs):
    ax_row[0].set_ylabel(f"{score}\ntrnsfm_exponent={trnsfm_exponent}")
for ax in axs[-1]:
    ax.set_xlabel('strain_depth_max')

for correlation_thresh, ax_col in zip(correlation_thresh_order, axs.T):
    ax_col[0].set_title(correlation_thresh)
axs[0, -1].set_title('rolling-mean trends')
axs[0, -1].legend(title='correlation_thresh', fontsize='small')

In [None]:
# f1 for every (genome, configuration), one column per unit, to compare units pairwise.
d = (
    xjin_benchmarking
    [lambda x: ~x.to_drop]
    .set_index(['unit', 'genome_id', 'correlation_thresh', 'depth_thresh', 'max_strain_samples', 'tool', 'seed', 'trnsfm_exponent'])
    .f1
    .unstack('unit')
)
pg = sns.pairplot(d, kind='hist', height=4, vars=unit_order)
# map_offdiag covers both triangles, so the y = x reference is drawn exactly once
# per off-diagonal panel (a following map_lower would draw it a second time).
pg.map_offdiag(lambda x, y, color=None, label=None: plt.gca().plot([0, 1], [0, 1], color=color, label=label));
pg.map_lower(lambda x, y, color=None, label=None: plt.gca().plot([0, 1], [0, 1], color=color, label=label))

In [None]:
# Eggnog benchmark metrics (ref-t25-p95 specgene params) against depth, per
# (genome, configuration).
d = (
    xjin_benchmarking
    [lambda x:
     ~x.to_drop
    ]
    .set_index(['unit', 'genome_id', 'correlation_thresh', 'depth_thresh', 'max_strain_samples', 'tool', 'seed', 'trnsfm_exponent', 'specgene_params'])
    .xs(('eggnog', 'ref-t25-p95'), level=('unit', 'specgene_params'))
    [['precision', 'recall', 'f1', 'species_gene_frac', 'strain_depth_sum', 'strain_depth_max']]
    # Mask non-positive depths before taking logs. Otherwise log10(0) = -inf
    # (with a divide-by-zero warning) would be passed to the histogram panels.
    .assign(
        log10_strain_depth_max=lambda x: np.log10(x.strain_depth_max.where(x.strain_depth_max > 0)),
        log10_strain_depth_sum=lambda x: np.log10(x.strain_depth_sum.where(x.strain_depth_sum > 0)),
    )
)

pg = sns.pairplot(d, kind='hist', height=4, vars=['precision', 'recall', 'f1', 'species_gene_frac', 'log10_strain_depth_sum', 'log10_strain_depth_max'])

In [None]:
# SPGC results on eggnog units: one row per (genome, remaining parameters), with
# species_gene_frac alongside the three accuracy scores.
d = (
    xjin_benchmarking
    .loc[lambda df: ~df.to_drop]
    .set_index([
        'unit', 'genome_id', 'correlation_thresh', 'depth_thresh',
        'max_strain_samples', 'tool', 'seed', 'trnsfm_exponent', 'specgene_params',
    ])
    .xs(('spgc', 'eggnog'), level=('tool', 'unit'))
    .loc[:, ['species_gene_frac', 'precision', 'recall', 'f1']]
)
# Restrict to the ref-t25-p95 specgene setting for the pairwise view.
pg = sns.pairplot(
    d.xs('ref-t25-p95', level='specgene_params'),
    vars=['precision', 'recall', 'f1', 'species_gene_frac'],
    kind='hist',
    height=4,
)

In [None]:
# Deciles (0.1 .. 0.9) of eggnog f1 for each (trnsfm_exponent, correlation_thresh),
# pivoted so that columns are (max_strain_samples, quantile).
d = (
    xjin_benchmarking
    [lambda x: ~x.to_drop]
    [lambda x: x.unit == 'eggnog']
    .groupby(['trnsfm_exponent', 'correlation_thresh', 'max_strain_samples'])
    .f1
    .quantile([q / 10 for q in range(1, 10)])
    .rename_axis(index={None: 'quantile'})
    .unstack(['max_strain_samples', 'quantile'])
)
fig, ax = plt.subplots(figsize=(10, 10))
# Colour encodes the f1 shortfall (1 - f1) for the max_strain_samples == 999 runs.
sns.heatmap(1 - d[999], ax=ax, cmap='viridis_r', norm=mpl.colors.PowerNorm(1/3))

In [None]:
# Deciles (0.1 .. 0.9) of eggnog f1 for every
# (max_strain_samples, trnsfm_exponent, correlation_thresh) configuration.
d = (
    xjin_benchmarking
    [lambda x: ~x.to_drop]
    [lambda x: x.unit == 'eggnog']
    .groupby(['max_strain_samples', 'trnsfm_exponent', 'correlation_thresh'])
    .f1
    .quantile([q / 10 for q in range(1, 10)])
    .rename_axis(index={None: 'quantile'})
    .unstack(['quantile'])
)
fig, ax = plt.subplots(figsize=(10, 10))
# Colour encodes the f1 shortfall (1 - f1).
sns.heatmap(1 - d, ax=ax, cmap='viridis_r', norm=mpl.colors.PowerNorm(1/3))

In [None]:
# Deciles (0.1 .. 0.9) of eggnog f1 for every
# (max_strain_samples, trnsfm_exponent, correlation_thresh) configuration.
d = (
    xjin_benchmarking
    [lambda x: ~x.to_drop]
    [lambda x: x.unit == 'eggnog']
    .groupby(['max_strain_samples', 'trnsfm_exponent', 'correlation_thresh'])
    .f1
    .quantile([q / 10 for q in range(1, 10)])
    .rename_axis(index={None: 'quantile'})
    .unstack(['quantile'])
)
fig, ax = plt.subplots(figsize=(10, 10))
# For the max_strain_samples == 999 runs, show how far each decile falls below
# the best value of that decile across all configurations (0 = best).
sns.heatmap(d.xs(999, level='max_strain_samples').sub(d.max()), ax=ax, cmap='viridis_r')

In [None]:
unit = 'eggnog'
# Best f1 achieved on each genome by any configuration.
best_score = (
    xjin_benchmarking
    [lambda x: ~x.to_drop & (x.unit == unit)]
    .groupby(['genome_id'])
    .f1
    .max()
)

# Rows: configurations. Columns: genomes. Values: f1 minus that genome's best
# (<= 0). A genome a configuration did not report counts as f1 = 0.
score_dropoff = (
    xjin_benchmarking
    [lambda x: ~x.to_drop & (x.unit == unit)]
    .set_index(['tool', 'max_strain_samples', 'trnsfm_exponent', 'depth_thresh', 'correlation_thresh', 'genome_id'])
    ['f1']
    .unstack('genome_id')
    .fillna(0)
    .sub(best_score)
)
fig, ax = plt.subplots(figsize=(10, 40))
# Per configuration, the distribution (50 evenly spaced quantiles across genomes)
# of the f1 shortfall relative to each genome's best.
sns.heatmap(
    -score_dropoff.quantile(np.linspace(0, 1), axis=1).T,
    ax=ax,
    cmap='viridis',
    norm=mpl.colors.PowerNorm(1/1),
    xticklabels=0,
)

In [None]:
# TODO: Check across all xjin strains that the accuracy doesn't decrease when including non-xjin samples.
# TODO: Check that species-gene-fraction is a good quality indicator (and maybe total genome size, too?) across the ground-truthed xjin strains.