### Preamble

#### Project Template

In [None]:
# Load IPython's autoreload extension.
# NOTE(review): only the extension is loaded here; no `%autoreload` mode is
# selected anywhere in this cell -- confirm that is intended.
%load_ext autoreload

In [None]:
# Make the project root (taken from the PROJECT_ROOT environment variable)
# the working directory, then show the resolved working directory.
# A KeyError is raised if PROJECT_ROOT is not set.
import os as _os

_project_root = _os.environ['PROJECT_ROOT']
_os.chdir(_project_root)
_os.path.realpath(_os.curdir)

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable
from datetime import datetime

In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

#### Set Style

In [None]:
# Plot styling: seaborn's 'talk' context and a figure DPI of 50.
sns.set_context(context='talk')
plt.rcParams.update({'figure.dpi': 50})

## Set Parameters / Load Data

In [None]:
# Fixed parameters; these are interpolated into the input file paths below.
group = 'XJIN_BENCHMARK'
stemA, stemB = 'r.proc', 'sfacts-fit'
centroid = 75
gene_params = "99-v22-agg" + str(centroid)

# Disabled parameters, kept for reference:
# depth_thresh="250"
# specgene_params='specgene-ref-t25-p90'

In [None]:
# Genome metadata, restricted to rows that have a genome_path; the distinct
# species_id values of those rows drive the loading loops below.
_genome_meta = pd.read_table('meta/genome.tsv', index_col='genome_id')
ref_strains = _genome_meta[_genome_meta.genome_path.notna()]
species_list = ref_strains.species_id.unique()

In [None]:
# One-row placeholder frame that stands in for a summary table which is
# missing or has no rows. It is built from a Series and transposed so that the
# field names become the columns of a single row.
EMPTY_DATA = pd.Series({
    'strain': None,
    'precision': np.nan,
    'recall': np.nan,
    'f1': np.nan,
    'genome_id': None,
    'correlation_thresh': np.nan,
    'depth_thresh': np.nan,
    'dominant_strain': None,
    'total_num_reference_genomes': -1,
    'total_num_xjin_strains': -1,
    'num_species_free_samples': -1,
    'num_strain_samples': -1,
    'strain_depth_sum': np.nan,
    'strain_depth_max': np.nan,
    'strain_depth_std': np.nan,
}).to_frame().T

In [None]:
# Spot check: modification time of one summary file, shown as a timestamp.
_example_summary = 'data/group/XJIN_BENCHMARK/species/sp-102492/r.proc.gtpro.sfacts-fit.gene99-v22-agg75.spgc_specgene-ref-t25-p95_ss-xjin-all_t-20_thresh-corr350-depth250.xjin_strain_summary.tsv'
_example_mtime = os.path.getmtime(_example_summary)
pd.to_datetime(_example_mtime, unit='s')

In [None]:
# Gather every per-species strain summary (SPGC over a parameter grid, plus
# PanPhlAn and StrainPanDA) into one long table, `xjin_benchmarking`, tagged
# with the parameters that produced each file and the file's mtime.

# TODO: Make these loop variables
seed = 0
max_strain_samples_order = [1, 5, 10, 20, -1]


def _load_summary(path):
    """Read one strain-summary table.

    Returns a ``(table, mtime)`` pair. When ``path`` does not exist the
    ``EMPTY_DATA`` placeholder is returned with a NaN mtime. When the file
    exists but contains no rows, the placeholder is returned together with
    the file's real mtime.
    """
    if not os.path.exists(path):
        return EMPTY_DATA, np.nan
    mtime = os.path.getmtime(path)
    table = pd.read_table(path)
    if table.empty:
        table = EMPTY_DATA
    return table, mtime


xjin_benchmarking = []

# SPGC
# The product is materialized with list() so that tqdm knows the total.
for (
    species,
    specgene_params,
    trnsfm_exponent,
    thresh_params,
    max_strain_samples,
    unit,
) in tqdm(list(product(
    species_list,
    ["ref-t25-p95"],
    [0.5, 1, 2, 3],
    ["corr0-depth250", "corr150-depth250", "corr200-depth250", "corr250-depth250", "corr350-depth250"],
    max_strain_samples_order,
    ['uhgg', 'ko', 'cog'],
)), ncols=50):
    # -1 is the sentinel for "use all samples".
    if max_strain_samples == -1:
        ss_params = "all"
    else:
        ss_params = f"deepest-n{max_strain_samples}"
    # The exponent appears in file names multiplied by ten (0.5 -> 5, 3 -> 30).
    trnsfm_exponent_str = int(round(trnsfm_exponent * 10, 0))
    path = f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_specgene-{specgene_params}_ss-xjin-{ss_params}_t-{trnsfm_exponent_str}_thresh-{thresh_params}.{unit}-xjin_strain_summary.tsv"
    d, mtime = _load_summary(path)

    xjin_benchmarking.append(d.assign(
        species=species,
        tool='spgc',
        unit=unit,
        specgene_params=specgene_params,
        seed=seed,
        max_strain_samples=max_strain_samples,
        thresh_params=thresh_params,
        trnsfm_exponent=trnsfm_exponent,
        path=path,
        run_datetime=pd.to_datetime(mtime, unit='s'),
    ))

# PanPhlAn and StrainPanDA
for species, tool_string, unit in tqdm(list(product(species_list, ['panphlan', 'spanda-s2'], ['uhgg', 'ko', 'cog'])), ncols=50):
    path = f"data/group/{group}/species/sp-{species}/{stemA}.gene{gene_params}.{tool_string}.{unit}-xjin_strain_summary.tsv"
    d, mtime = _load_summary(path)

    xjin_benchmarking.append(d.assign(
        species=species,
        # 'spanda-s2' is reported simply as 'spanda'.
        tool=tool_string.split('-')[0],
        unit=unit,
        path=path,
        run_datetime=pd.to_datetime(mtime, unit='s'),
    ))

xjin_benchmarking = pd.concat(xjin_benchmarking).assign(
    # Reasonable filters:
    to_drop=lambda x: ( False
        | (x.total_num_reference_genomes > 1)
        | (x.strain != x.dominant_strain)
        # | ((x.num_strain_samples != x.max_strain_samples) & (x.max_strain_samples > 0))
    )
).reset_index(drop=True)

xjin_benchmarking

In [None]:
# Distinct values of every column among retained SPGC rows whose summary file
# had a modification time.
_dated_spgc = xjin_benchmarking[
    lambda df: ~df.to_drop & df.run_datetime.notna() & df.tool.isin(['spgc'])
]
_dated_spgc.apply(lambda col: col.unique())

In [None]:
# Sorted file modification times, plotted against their rank.
_run_times = xjin_benchmarking.run_datetime.sort_values()
plt.plot(_run_times.values)

In [None]:
# First distinct path among retained rows that have an F1 score.
_scored = xjin_benchmarking[lambda df: ~df.to_drop].dropna(subset=['f1'])
_scored.apply(lambda col: col.unique())['path'][0]

In [None]:
# These species are expected to have no row with a correlation threshold.
for species in ('100273', '100878', '102386', 'TODO'):
    _with_thresh = xjin_benchmarking[
        (xjin_benchmarking.species == species)
        & xjin_benchmarking.correlation_thresh.notna()
    ]
    assert _with_thresh.empty

# Distinct values of every column among rows lacking a correlation threshold.
_no_thresh = xjin_benchmarking[xjin_benchmarking.correlation_thresh.isna()]
_no_thresh.apply(lambda col: col.unique())

In [None]:
# Overlaid F1 histograms, one per transformation exponent, for UHGG units
# using all samples at thresholds corr350-depth250.
bins = np.linspace(0, 1, num=21)

for trnsfm_exponent in (0.5, 1, 2, 3):
    d = xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & (df.specgene_params == "ref-t25-p95")
            & (df.max_strain_samples == -1)
            & (df.thresh_params == "corr350-depth250")
            & (df.trnsfm_exponent == trnsfm_exponent)
        )
    ]
    plt.hist(d.f1, bins=bins, alpha=0.5, label=trnsfm_exponent)
# Disabled variant:
# plt.hist(d[lambda x: x.strain_depth_sum < 0.5].f1, bins=bins, alpha=0.5)
plt.legend()
None

In [None]:
# For each (threshold, sample-count) setting, plot the fraction of strains
# whose F1 exceeds each observed F1 value (UHGG units, exponent 1.0).
bins = np.linspace(0.4, 1, num=500)

fig, ax = plt.subplots()

for thresh_params, max_strain_samples in product(
    ("corr0-depth250", "corr350-depth250"), (1, 10)
):
    d = xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & (df.specgene_params == "ref-t25-p95")
            & (df.max_strain_samples == max_strain_samples)
            & (df.thresh_params == thresh_params)
            & (df.trnsfm_exponent == 1.0)
        )
    ].sort_values('f1')
    # Rows are sorted by F1, so the i-th row (1-based) has i / n of the rows
    # at or below it; plot the complement.
    _n_strains = len(d)
    _frac_above = 1 - np.arange(1, _n_strains + 1) / _n_strains
    ax.plot(d.f1, _frac_above, label=(thresh_params, max_strain_samples))

# NOTE: earlier, disabled versions of this cell drew cumulative step
# histograms (using `bins`) of F1 for strains with strain_depth_sum >= 0.5,
# labelled "Depth-only (1)", "Depth-only (10)", "SPGC (1)", "SPGC (10)" and
# "SPGC (all)".
# ax.set_xscale('log')
# ax.invert_xaxis()
ax.legend(bbox_to_anchor=(1, 1))
ax.set_xlabel('f1')
ax.set_ylabel('frac strains')

In [None]:
# Retained SPGC / UHGG rows at exponent 1.0 for four of the threshold
# settings; display the distinct values of each column.
_thresh_subset = ["corr0-depth250", "corr200-depth250", "corr250-depth250", "corr350-depth250"]
d = xjin_benchmarking.loc[
    lambda df: (
        ~df.to_drop
        & df.unit.isin(['uhgg'])
        & df.tool.isin(['spgc'])
        & (df.specgene_params == "ref-t25-p95")
        & df.thresh_params.isin(_thresh_subset)
        & (df.trnsfm_exponent == 1.0)
    )
]
d.apply(lambda col: col.unique())

In [None]:
# Swarm plots of precision, recall and F1 against the number of strain
# samples, coloured by threshold setting (SPGC, UHGG units, exponent 1.0).
d = (
    xjin_benchmarking
    [lambda x: True
     & ~x.to_drop
     & x.unit.isin(['uhgg'])
     & x.tool.isin(['spgc'])
     & (x.specgene_params == "ref-t25-p95")
     & x.thresh_params.isin(["corr0-depth250", "corr200-depth250", "corr250-depth250", "corr350-depth250"])
     & (x.trnsfm_exponent == 1.0)
    ]
)

# NaNs in max_strain_samples mean that the integer-typed order list doesn't
# work, so a float copy is passed to seaborn instead.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, axs = plt.subplots(3, figsize=(25, 15))

for y, ax in zip(["precision", "recall", "f1"], axs):
    # Fixed: `order` previously referenced the undefined name
    # `_max_strain_samples`, which raised a NameError.
    sns.swarmplot(data=d, x='max_strain_samples', y=y, hue='thresh_params', order=_max_strain_samples_order, ax=ax, dodge=True)
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_ylim(-0.05, 1.05)
# plt.yscale('logit')
# sns.swarmplot(data=d, x='max_strain_samples', y='f1', order=max_strain_samples_order)
# sns.box(data=d, x='max_strain_samples', y='f1', order=max_strain_samples_order)

In [None]:
# Swarm plots of precision, recall and F1 against the number of strain
# samples, coloured by threshold setting (SPGC, KO units, exponent 1.0).
_thresh_subset = ["corr0-depth250", "corr200-depth250", "corr250-depth250", "corr350-depth250"]
d = xjin_benchmarking.loc[
    lambda df: (
        ~df.to_drop
        & df.unit.isin(['ko'])
        & df.tool.isin(['spgc'])
        & (df.specgene_params == "ref-t25-p95")
        & df.thresh_params.isin(_thresh_subset)
        & (df.trnsfm_exponent == 1.0)
    )
]

# NaNs in max_strain_samples mean that the integer-typed order list doesn't
# work, so a float copy is passed to seaborn instead.
_max_strain_samples_order = np.asarray(max_strain_samples_order, dtype=float)

fig, axs = plt.subplots(3, figsize=(25, 15))

for ax, y in zip(axs, ("precision", "recall", "f1")):
    sns.swarmplot(
        data=d,
        x='max_strain_samples',
        y=y,
        hue='thresh_params',
        order=_max_strain_samples_order,
        dodge=True,
        ax=ax,
    )
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_ylim(-0.05, 1.05)
# plt.yscale('logit')

In [None]:
# Per-genome change in F1 going from corr0-depth250 to corr200-depth250, at
# each number of strain samples (SPGC, UHGG units, exponent 1.0).
_f1_by_thresh = (
    xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & df.tool.isin(['spgc'])
            & (df.specgene_params == "ref-t25-p95")
            & (df.trnsfm_exponent == 1.0)
        )
    ]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    ['f1']
    .unstack('thresh_params')
)
d0 = (_f1_by_thresh["corr200-depth250"] - _f1_by_thresh["corr0-depth250"]).rename('delta')

# NaNs in max_strain_samples mean that the integer-typed order list doesn't
# work, so a float copy is passed to seaborn instead.
_max_strain_samples_order = np.asarray(max_strain_samples_order, dtype=float)

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(
    data=d0.reset_index(),
    x='max_strain_samples',
    y='delta',
    order=_max_strain_samples_order,
    ax=ax,
)
plt.axhline(0, color='k', linestyle='--', lw=1)

# Wilcoxon signed-rank test of the deltas, plus their mean and median, for
# each number of strain samples.
d1 = d0.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    _deltas = d1[max_strain_samples]
    print(sp.stats.wilcoxon(_deltas.dropna()), _deltas.mean(), _deltas.median())

In [None]:
# F1 at corr350-depth250 minus F1 of the single-sample corr0-depth250 run,
# per genome and number of strain samples (SPGC, KO units, exponent 1.0).
d0 = (
    xjin_benchmarking[lambda x: True
        & ~x.to_drop
        & x.unit.isin(['ko'])
        & x.tool.isin(['spgc'])
        & (x.specgene_params == "ref-t25-p95")
        & (x.trnsfm_exponent == 1.0)
    ]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    .f1
    .unstack(['thresh_params', 'max_strain_samples'])
)
d1 = (
    d0
    # Subtract the reference column from every column.
    .apply(lambda y: y - d0["corr0-depth250", 1])
    .stack(['thresh_params', 'max_strain_samples'])
    .rename("delta")
    .xs("corr350-depth250", level="thresh_params")
)

# NaNs in max_strain_samples mean that the integer-typed order list doesn't
# work. This cell previously passed the integer list; it now uses the float
# copy, as the other swarm-plot cells in this notebook do.
_max_strain_samples_order = np.array(max_strain_samples_order).astype(float)

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(data=d1.reset_index(), x='max_strain_samples', y='delta', order=_max_strain_samples_order, ax=ax)
plt.axhline(0, color='k', lw=1, linestyle='--')
# # plt.legend(bbox_to_anchor=(1, 1))

# Wilcoxon signed-rank test of the deltas, plus their mean and median, for
# each number of strain samples.
d2 = d1.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    print(sp.stats.wilcoxon(d2[max_strain_samples].dropna()), d2[max_strain_samples].mean(), d2[max_strain_samples].median())

In [None]:
# F1 at corr350-depth250 minus F1 of the 5-sample corr0-depth250 run, per
# genome and number of strain samples (SPGC, UHGG units, exponent 1.0).
d0 = (
    xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & df.tool.isin(['spgc'])
            & (df.specgene_params == "ref-t25-p95")
            & (df.trnsfm_exponent == 1.0)
        )
    ]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    ['f1']
    .unstack(['thresh_params', 'max_strain_samples'])
)
_reference_f1 = d0[("corr0-depth250", 5)]
d1 = (
    d0
    .sub(_reference_f1, axis=0)
    .stack(['thresh_params', 'max_strain_samples'])
    .rename("delta")
    .xs("corr350-depth250", level="thresh_params")
)

# NaNs in max_strain_samples mean that the integer-typed order list doesn't
# work, so a float copy is passed to seaborn instead.
_max_strain_samples_order = np.asarray(max_strain_samples_order, dtype=float)

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(
    data=d1.reset_index(),
    x='max_strain_samples',
    y='delta',
    order=_max_strain_samples_order,
    ax=ax,
)
plt.axhline(0, color='k', linestyle='--', lw=1)

# Wilcoxon signed-rank test of the deltas, plus their mean and median, for
# each number of strain samples.
d2 = d1.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    _deltas = d2[max_strain_samples]
    print(sp.stats.wilcoxon(_deltas.dropna()), _deltas.mean(), _deltas.median())

In [None]:
# Rows using all samples at corr350-depth250; display the last ten rows when
# ordered by descending F1.
d = xjin_benchmarking.loc[
    lambda df: (
        ~df.to_drop
        & (df.specgene_params == "ref-t25-p95")
        & (df.max_strain_samples == -1)
        & (df.thresh_params == "corr350-depth250")
    )
]

# d[d.species == "102395"]
d.sort_values("f1", ascending=False).tail(10)

In [None]:
# Pair each genome's scores from two SPGC settings -- 5 samples at
# corr0-depth250 (unsuffixed columns) versus 10 samples at corr350-depth250
# ("_alt" columns) -- and scatter one against the other for each score.
unit = 'uhgg'

# Rows satisfying the conditions shared by both settings.
_spgc_rows = xjin_benchmarking.loc[
    lambda df: (
        ~df.to_drop
        & df.unit.isin([unit])
        & df.tool.isin(['spgc'])
        & (df.specgene_params == "ref-t25-p95")
        & (df.trnsfm_exponent == 1.0)
    )
]
_baseline = _spgc_rows[
    lambda df: (df.max_strain_samples == 5) & (df.thresh_params == "corr0-depth250")
]
_alternative = _spgc_rows[
    lambda df: (df.max_strain_samples == 10) & (df.thresh_params == "corr350-depth250")
]
d = _baseline.merge(_alternative, on=["genome_id"], suffixes=("", "_alt"))
# NOTE: a third merge (specgene_params "ref-t25-p85", all samples,
# thresh_params "alpha200", suffix "_alt2") was previously drafted here and
# left disabled.

fig, axs = plt.subplots(3, figsize=(6, 15))

# Points are drawn in order of increasing strain_depth_sum.
_by_depth = d.sort_values("strain_depth_sum")
for ax, score in zip(axs.flatten(), ('precision', 'recall', 'f1')):
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    artist = ax.scatter(
        score,
        score + '_alt',
        c="strain_depth_sum",
        data=_by_depth,
        norm=mpl.colors.LogNorm(),
    )
    cbar = fig.colorbar(artist, cax=cax)
    ax.plot([0.01, 0.99], [0.01, 0.99])
    ax.set_title(score)
    ax.set_aspect(1)
fig.tight_layout()

In [None]:
# Precision at corr200-depth250 minus precision of the all-sample
# corr350-depth250 run, per genome and number of strain samples
# (SPGC, UHGG units, exponent 1.0).
d0 = (
    xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & df.tool.isin(['spgc'])
            & (df.specgene_params == "ref-t25-p95")
            & (df.trnsfm_exponent == 1.0)
        )
    ]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    ['precision']
    .unstack(['thresh_params', 'max_strain_samples'])
)
_reference_precision = d0[("corr350-depth250", -1)]
d1 = (
    d0
    .sub(_reference_precision, axis=0)
    .stack(['thresh_params', 'max_strain_samples'])
    .rename("delta")
    .xs("corr200-depth250", level="thresh_params")
)

fig, ax = plt.subplots(figsize=(25, 5))

# NaNs in max_strain_samples mean that the integer-typed order list doesn't
# work, so a float copy is passed to seaborn instead.
_max_strain_samples_order = np.asarray(max_strain_samples_order, dtype=float)

sns.swarmplot(
    data=d1.reset_index(),
    x='max_strain_samples',
    y='delta',
    order=_max_strain_samples_order,
    ax=ax,
)
plt.axhline(0, color='k', linestyle='--', lw=1)

# Wilcoxon signed-rank test of the deltas, plus their mean and median, for
# each number of strain samples.
d2 = d1.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    _deltas = d2[max_strain_samples]
    print(sp.stats.wilcoxon(_deltas.dropna()), _deltas.mean(), _deltas.median())

In [None]:
# Deltas for the all-sample setting (-1), smallest first.
_delta_long = d1.reset_index()
_delta_long[_delta_long.max_strain_samples == -1].sort_values('delta')

In [None]:
# Swarm plots of precision, recall and F1 against the number of strain
# samples, coloured by transformation exponent
# (SPGC, UHGG units, corr350-depth250).
d = xjin_benchmarking.loc[
    lambda df: (
        ~df.to_drop
        & df.unit.isin(['uhgg'])
        & df.tool.isin(['spgc'])
        & (df.specgene_params == "ref-t25-p95")
        & df.thresh_params.isin(["corr350-depth250"])
        & df.trnsfm_exponent.isin([0.5, 1.0, 2.0, 3.0])
    )
]

# NaNs in max_strain_samples mean that the integer-typed order list doesn't
# work, so a float copy is passed to seaborn instead.
_max_strain_samples_order = np.asarray(max_strain_samples_order, dtype=float)

fig, axs = plt.subplots(3, figsize=(25, 15))

for ax, y in zip(axs, ("precision", "recall", "f1")):
    sns.swarmplot(
        data=d,
        x='max_strain_samples',
        y=y,
        hue='trnsfm_exponent',
        order=_max_strain_samples_order,
        dodge=True,
        ax=ax,
    )
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_ylim(-0.05, 1.05)
# plt.yscale('logit')

In [None]:
# Six pairwise F1 comparisons between SPGC settings (UHGG units, depth
# threshold 0.25). Each setting is addressed in `d1` by the column key
# (metric, correlation_thresh, trnsfm_exponent, max_strain_samples).
d0 = (
    xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & df.tool.isin(['spgc'])
            & (df.specgene_params == "ref-t25-p95")
        )
    ]
    .set_index(['genome_id', 'correlation_thresh', 'depth_thresh', 'specgene_params', 'max_strain_samples', 'trnsfm_exponent'])
    .xs((0.25, 'ref-t25-p95'), level=('depth_thresh', 'specgene_params'))
)

d1 = d0.unstack(['correlation_thresh', 'trnsfm_exponent', 'max_strain_samples'])
# Show which values occur at each level of the column index.
print(d1.columns.unique().to_frame().apply(lambda col: col.unique()))

score = 'f1'
c = 'strain_depth_sum'

fig, axs = plt.subplots(2, 3, figsize=(25, 15))

# (title, x setting, y setting); settings are
# (correlation_thresh, trnsfm_exponent, max_strain_samples).
_panels = [
    ('Improves on single-sample depth-only approach.', (0.0, 1.0, 1), (0.2, 3.0, -1)),
    ('Improves on BEST depth-only performance (10 samples).', (0.0, 1.0, 10), (0.2, 3.0, -1)),
    ('Similar performance to best untransformed approach at 20 samples', (0.35, 1.0, 20), (0.2, 3.0, 20)),
    ("Robust to number of samples.", (0.2, 3.0, 20), (0.2, 3.0, -1)),
    ("Robust to higher cutoff.", (0.2, 3.0, -1), (0.25, 3.0, -1)),
    ("Robust to lower cutoff.", (0.2, 3.0, -1), (0.15, 3.0, -1)),
]

for ax, (_title, _x_setting, _y_setting) in zip(axs.flatten(), _panels):
    ax.set_title(_title)
    # Every panel is coloured by the same strain_depth_sum column.
    artist = ax.scatter(
        d1[(score, *_x_setting)],
        d1[(score, *_y_setting)],
        c=d1[(c, 0.25, 2.0, 20)],
        norm=mpl.colors.SymLogNorm(linthresh=1),
    )
    ax.plot([0, 1], [0, 1])
    fig.colorbar(artist, shrink=0.3)
    ax.set_aspect(1)

In [None]:
# Paired comparison of F1 between two all-sample SPGC settings:
# x = (correlation_thresh 0.0, exponent 1.0), y = (0.2, 3.0).
# x, y and z are reused by the following cells.
d0 = (
    xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & df.tool.isin(['spgc'])
            & (df.specgene_params == "ref-t25-p95")
        )
    ]
    .set_index(['genome_id', 'correlation_thresh', 'depth_thresh', 'specgene_params', 'max_strain_samples', 'trnsfm_exponent'])
    .xs((0.25, 'ref-t25-p95'), level=('depth_thresh', 'specgene_params'))
)

d1 = d0.unstack(['correlation_thresh', 'trnsfm_exponent', 'max_strain_samples'])

score = 'f1'
c = 'strain_depth_sum'
# Settings are (correlation_thresh, trnsfm_exponent, max_strain_samples).
_baseline_setting = (0.0, 1.0, -1)
_improved_setting = (0.2, 3.0, -1)
x = d1[(score, *_baseline_setting)]
y = d1[(score, *_improved_setting)]
z = d1[(c, *_improved_setting)]

print(sp.stats.wilcoxon(x, y))
print((y - x).quantile([0.1, 0.25, 0.75, 0.9]))

In [None]:
# Difference y - x against z, on a log x-axis.
_gain = y - x
plt.scatter(z, _gain)
plt.xscale('log')
# plt.yscale('symlog', linthresh=0.1)
plt.axhline(0, color='k', linestyle='--', lw=1)
plt.ylim(-1.0, 1.0)

In [None]:
# Mean and median of y - x, and the fraction of entries where it is positive.
_gain = y - x
(_gain.mean(), _gain.median(), (_gain > 0).mean())

In [None]:
# Precision, recall and F1 of two all-sample SPGC settings plotted against
# each other, coloured by strain_depth_sum (UHGG units, depth threshold 0.25).
d0 = (
    xjin_benchmarking.loc[
        lambda df: (
            ~df.to_drop
            & df.unit.isin(['uhgg'])
            & df.tool.isin(['spgc'])
            & (df.specgene_params == "ref-t25-p95")
        )
    ]
    .set_index(['depth_thresh', 'specgene_params', 'genome_id', 'correlation_thresh', 'max_strain_samples', 'trnsfm_exponent'])
    .xs((0.25, 'ref-t25-p95'), level=('depth_thresh', 'specgene_params'))
)

d1 = d0.unstack(['correlation_thresh', 'trnsfm_exponent', 'max_strain_samples'])

# Settings are (correlation_thresh, trnsfm_exponent, max_strain_samples).
_baseline_setting = (0.0, 1.0, -1)
_improved_setting = (0.2, 3.0, -1)

fig, axs = plt.subplots(1, 3, figsize=(20, 4.5))
for ax, score in zip(axs.flatten(), ('precision', 'recall', 'f1')):
    c = 'strain_depth_sum'
    x = d1[(score, *_baseline_setting)]
    y = d1[(score, *_improved_setting)]
    z = d1[(c, *_improved_setting)]
    artist = ax.scatter(x, y, c=z, norm=mpl.colors.SymLogNorm(linthresh=1))
    ax.plot([0, 1], [0, 1])
    ax.set_title(score)
    fig.colorbar(artist)

In [None]:
# Compare per-genome F1 between PanPhlAn, StrainPanDA and two SPGC settings
# on UHGG units. Genomes missing for a tool are given a score of 0.
unit = 'uhgg'

d0 = xjin_benchmarking[lambda df: ~df.to_drop & df.unit.isin([unit])]


def _spgc_selector(correlation_thresh, trnsfm_exponent):
    """Build a row selector for all-sample SPGC runs with the given settings."""
    def _select(df):
        return (
            (df.tool == 'spgc')
            & (df.specgene_params == 'ref-t25-p95')
            & (df.depth_thresh == 0.25)
            & (df.correlation_thresh == correlation_thresh)
            & (df.max_strain_samples == -1)
            & (df.trnsfm_exponent == trnsfm_exponent)
        )
    return _select


def selector_panphlan(df):
    """Select PanPhlAn rows."""
    return df.tool == 'panphlan'


def selector_spanda(df):
    """Select StrainPanDA rows."""
    return df.tool == 'spanda'


selector_spgc1 = _spgc_selector(0, 1)
selector_spgc2 = _spgc_selector(0.20, 3)

score = 'f1'
_selectors = {
    'spgc1': selector_spgc1,
    'spgc2': selector_spgc2,
    'panphlan': selector_panphlan,
    'spanda': selector_spanda,
}
d1 = pd.DataFrame({
    _label: d0[_selector].set_index('genome_id')[score]
    for _label, _selector in _selectors.items()
}).fillna(0)

fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)
axs = axs.flatten()

# Each panel: one tool's score against another's, with the identity line.
ax = axs[0]
ax.scatter('panphlan', 'spanda', data=d1)
ax.plot([0, 1], [0, 1])

ax = axs[1]
ax.scatter('panphlan', 'spgc1', data=d1)
ax.plot([0, 1], [0, 1])

ax = axs[2]
ax.scatter('panphlan', 'spgc2', data=d1)
ax.plot([0, 1], [0, 1])

# ax.set_xlim(0.5, 1.05)
# ax.set_ylim(0.5, 1.05)
ax = axs[3]
ax.scatter('spgc1', 'spgc2', data=d1)
ax.plot([0, 1], [0, 1])

In [None]:
# Per-genome F1 of each tool on COG units, plotted against strain depth.
# Genomes missing for a tool (and missing depths) are filled with 0.
unit = 'cog'

d0 = xjin_benchmarking[lambda df: ~df.to_drop & df.unit.isin([unit])]


def selector_panphlan(df):
    """Select PanPhlAn rows."""
    return df.tool == 'panphlan'


def selector_spanda(df):
    """Select StrainPanDA rows."""
    return df.tool == 'spanda'


def selector_spgc1(df):
    """Select all-sample SPGC rows with correlation_thresh 0 and exponent 1."""
    return (
        (df.tool == 'spgc')
        & (df.specgene_params == 'ref-t25-p95')
        & (df.depth_thresh == 0.25)
        & (df.correlation_thresh == 0)
        & (df.max_strain_samples == -1)
        & (df.trnsfm_exponent == 1)
    )


def selector_spgc2(df):
    """Select all-sample SPGC rows with correlation_thresh 0.20 and exponent 3."""
    return (
        (df.tool == 'spgc')
        & (df.specgene_params == 'ref-t25-p95')
        & (df.depth_thresh == 0.25)
        & (df.correlation_thresh == 0.20)
        & (df.max_strain_samples == -1)
        & (df.trnsfm_exponent == 3)
    )


score = 'f1'
# The depth columns are taken from the spgc2 rows.
_spgc2_rows = d0[selector_spgc2].set_index('genome_id')
d1 = pd.DataFrame({
    'panphlan': d0[selector_panphlan].set_index('genome_id')[score],
    'spanda': d0[selector_spanda].set_index('genome_id')[score],
    'spgc1': d0[selector_spgc1].set_index('genome_id')[score],
    'spgc2': _spgc2_rows[score],
    'strain_depth_max': _spgc2_rows['strain_depth_max'],
    'strain_depth_sum': _spgc2_rows['strain_depth_sum'],
}).fillna(0)

fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)

for ax, tool in zip(axs.flatten(), ('spanda', 'panphlan', 'spgc1', 'spgc2')):
    ax.scatter('strain_depth_max', tool, data=d1, alpha=0.5)
    ax.set_xscale('symlog', linthresh=0.1)
    ax.set_title(tool)
    # Reference lines at depth 1.0 and at scores 0.9 and 0.8.
    ax.axvline(1.0, color='k', linestyle='--', lw=1)
    ax.axhline(0.9, color='k', linestyle=':', lw=1)
    ax.axhline(0.8, color='k', linestyle='--', lw=1)
# ax.set_xlim(0.5, 1.05)
# ax.set_ylim(0.5, 1.05)

# TODO: strain_depth_max and strain_depth_sum are wrong when sfacts doesn't detect a single strain in xjin samples.
fig, ax = plt.subplots()
ax.scatter('strain_depth_sum', 'spgc2', data=d1, alpha=0.5)
ax.set_xscale('symlog', linthresh=0.1)
ax.axvline(1.0, color='k', linestyle='--', lw=1)
ax.axhline(0.9, color='k', linestyle=':', lw=1)
ax.axhline(0.8, color='k', linestyle='--', lw=1)

In [None]:
# Genomes where the spgc2 score is below 0.5.
d1[d1.spgc2 < 0.5]

In [None]:
# Genomes where both the panphlan and spanda scores fall below the threshold.
thresh = 0.5
_both_low = (d1.panphlan < thresh) & (d1.spanda < thresh)
d1[_both_low]

In [None]:
# Path of the first row matched by the spgc2 selector.
_spgc2_rows = d0[selector_spgc2]
_spgc2_rows['path'].iloc[0]