### Preamble

#### Project Template

In [None]:
# Load IPython's autoreload extension.
# NOTE(review): no `%autoreload` mode is selected in the visible cells, so
# presumably automatic reloading is not actually enabled -- confirm intent.
%load_ext autoreload

In [None]:
# Switch the working directory to the path stored in the PROJECT_ROOT
# environment variable (raises KeyError if it is unset); later cells read
# files through relative paths such as 'meta/genome.tsv'.
import os as _os
_os.chdir(_os.environ['PROJECT_ROOT'])
# Last expression of the cell: show the resolved working directory.
_os.path.realpath(_os.path.curdir)

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable


In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

#### Set Style

In [None]:
# Use seaborn's 'talk' context for all figures in this notebook.
sns.set_context('talk')
# Render figures at 50 dpi.
plt.rcParams['figure.dpi'] = 50

## Set Parameters / Load Data

In [None]:
# Fixed params
# These values are interpolated into the benchmarking file paths built in the
# loading cell below.

group='xjin_hmp2'
stemA='r.proc'
# Only used to build `gene_params` below.
centroid=75
stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0'
gene_params = f"99-v22-agg{centroid}"
# depth_thresh="250"
# specgene_params='specgene-ref-t25-p90'

In [None]:
# Reference genome metadata, restricted to genomes that have a genome_path.
_genome_meta = pd.read_table('meta/genome.tsv', index_col='genome_id')
ref_strains = _genome_meta[_genome_meta.genome_path.notna()]
# Distinct species represented among those reference genomes.
species_list = ref_strains.species_id.unique()

In [None]:
# Load every available per-strain benchmarking summary and tag each table with
# the parameter combination that produced it.
# TODO: Make these loop variables
seed = 0
# Plot/category order for max_strain_samples; -1 stands for "all samples".
max_strain_samples_order = [1, 2, 3, 5, 10, 20, 50, -1]
specgene_params_list = ["ref-t25-p85", "ref-t25-p90", "ref-t25-p95"]
thresh_params_list = ["corr0-depth250", "corr350-depth250", "alpha50", "alpha100", "alpha200"]

# Collect the individual tables under their own name so that
# `xjin_benchmarking` is only ever a DataFrame.
xjin_benchmarking_tables = []
for species in tqdm(species_list):
    for specgene_params, thresh_params, max_strain_samples in product(
        specgene_params_list, thresh_params_list, max_strain_samples_order
    ):
        if max_strain_samples == -1:
            ss_params = "all"
        else:
            ss_params = f"deepest-n{max_strain_samples}"
        path = f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.gene{gene_params}.spgc_specgene-{specgene_params}_ss-xjin-{ss_params}_thresh-{thresh_params}.xjin_strain_summary.tsv"
        # Parameter combinations without a result file are skipped silently.
        if not os.path.exists(path):
            continue

        xjin_benchmarking_tables.append(pd.read_table(path).assign(
            species=species,
            specgene_params=specgene_params,
            seed=seed,
            max_strain_samples=max_strain_samples,
            thresh_params=thresh_params,
        ))

# Fail with an actionable message instead of pandas' generic
# "No objects to concatenate" when nothing was found on disk.
if not xjin_benchmarking_tables:
    raise ValueError(
        f"No xjin_strain_summary.tsv files found under data/group/{group}/species/; "
        "check PROJECT_ROOT and the fixed parameters."
    )

xjin_benchmarking = pd.concat(xjin_benchmarking_tables).assign(
    # Reasonable filters: flag rows where more than one reference genome is
    # involved or where the row's strain is not the dominant strain.
    to_drop=lambda x: ( False
        | (x.total_num_reference_genomes > 1)
        | (x.strain != x.dominant_strain)
        # | ((x.num_strain_samples != x.max_strain_samples) & (x.max_strain_samples > 0))
    )
)

xjin_benchmarking

In [None]:
# Histogram of F1 for the depth-only threshold using all samples, overlaid
# with the subset of strains whose strain_depth_sum is below 0.5.
bins = np.linspace(0, 1, num=51)

d = xjin_benchmarking.loc[
    lambda x: (
        ~x.to_drop
        & x.specgene_params.eq("ref-t25-p95")
        & x.max_strain_samples.eq(-1)
        & x.thresh_params.eq("corr0-depth250")
    )
]
_is_low_depth = d.strain_depth_sum < 0.5

plt.hist(d.f1, bins=bins, alpha=0.5)
plt.hist(d.loc[_is_low_depth, "f1"], bins=bins, alpha=0.5)

# Suppress the repr of the last hist() call.
None

In [None]:
# For each (threshold, number of samples) combination, plot the fraction of
# strains whose F1 exceeds a given value (one minus the empirical CDF).
bins = np.linspace(0.4, 1, num=500)

fig, ax = plt.subplots()

for thresh_params, max_strain_samples in product(["corr0-depth250", "corr350-depth250"], [1, 10]):
    d = xjin_benchmarking.loc[
        lambda x: (
            ~x.to_drop
            & x.specgene_params.eq("ref-t25-p95")
            & x.max_strain_samples.eq(max_strain_samples)
            & x.thresh_params.eq(thresh_params)
        )
    ].sort_values('f1')
    # Rows are sorted by F1, so the k-th row has k/n strains at or below it.
    frac_at_or_below = np.arange(1, len(d) + 1) / len(d)
    ax.plot(d.f1, 1 - frac_at_or_below, label=(thresh_params, max_strain_samples))

ax.legend(bbox_to_anchor=(1, 1))
ax.set_xlabel('f1')
ax.set_ylabel('frac strains')

In [None]:
# Precision / recall / F1 against the number of samples per strain, split by
# threshold setting (one panel per score).
d = xjin_benchmarking.loc[
    lambda x: (
        ~x.to_drop
        & x.specgene_params.eq("ref-t25-p95")
        & x.thresh_params.isin(["corr0-depth250", "corr350-depth250"])
    )
]

fig, axs = plt.subplots(3, figsize=(25, 15))

for y, ax in zip(["precision", "recall", "f1"], axs):
    sns.swarmplot(
        data=d,
        x='max_strain_samples',
        y=y,
        hue='thresh_params',
        order=max_strain_samples_order,
        dodge=True,
        ax=ax,
    )
    ax.legend(bbox_to_anchor=(1, 1))
    ax.set_ylim(-0.05, 1.05)

In [None]:
# Per-genome difference in F1 between the two thresholds
# ("corr350-depth250" minus "corr0-depth250") at each max_strain_samples.
_f1_by_thresh = (
    xjin_benchmarking
    .loc[lambda x: ~x.to_drop & x.specgene_params.eq("ref-t25-p95")]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    ['f1']
    .unstack('thresh_params')
)
d0 = (_f1_by_thresh["corr350-depth250"] - _f1_by_thresh["corr0-depth250"]).rename("delta")

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(data=d0.reset_index(), x='max_strain_samples', y='delta', order=max_strain_samples_order, ax=ax)
ax.axhline(0, color='k', lw=1, linestyle='--')

# One column of deltas per max_strain_samples value.
d1 = d0.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    deltas = d1[max_strain_samples]
    # Wilcoxon signed-rank test on the non-missing deltas, then mean and median.
    print(sp.stats.wilcoxon(deltas.dropna()), deltas.mean(), deltas.median())

In [None]:
# Deltas for max_strain_samples == 50, smallest first.
d1.loc[:, 50].sort_values()

In [None]:
# Change in F1 for the "corr350-depth250" threshold at every max_strain_samples,
# measured against a fixed baseline column: the "corr0-depth250" result at
# max_strain_samples == 1.
baseline = ("corr0-depth250", 1)

d0 = (
    xjin_benchmarking
    [lambda x: ~x.to_drop & (x.specgene_params == "ref-t25-p95")]
    .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
    .f1
    .unstack(['thresh_params', 'max_strain_samples'])
    # Subtract the baseline column from every column, row by row. The previous
    # `.apply(lambda y: y - x[...])` referenced an undefined name `x` and raised
    # NameError on a fresh kernel.
    .pipe(lambda x: x.sub(x[baseline], axis="index"))
    .stack(['thresh_params', 'max_strain_samples'])
    .rename("delta")
    .xs("corr350-depth250", level="thresh_params")
)

fig, ax = plt.subplots(figsize=(25, 5))

sns.swarmplot(data=d0.reset_index(), x='max_strain_samples', y='delta', order=max_strain_samples_order, ax=ax)
plt.axhline(0, color='k', lw=1, linestyle='--')

# One column of deltas per max_strain_samples value; print the Wilcoxon
# signed-rank result, mean and median for each.
d1 = d0.unstack("max_strain_samples")
for max_strain_samples in max_strain_samples_order:
    print(sp.stats.wilcoxon(d1[max_strain_samples].dropna()), d1[max_strain_samples].mean(), d1[max_strain_samples].median())

In [None]:
# Change in precision / recall / F1 for the "corr350-depth250" threshold at
# every max_strain_samples, measured against a fixed baseline column: the
# "corr0-depth250" result at max_strain_samples == 10.
baseline = ("corr0-depth250", 10)

fig, axs = plt.subplots(3, figsize=(25, 15))


for score, ax in zip(["precision", "recall" , "f1"], axs.flatten()):
    d0 = (
        xjin_benchmarking
        [lambda x: ~x.to_drop & (x.specgene_params == "ref-t25-p95")]
        .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
        [score]
        .unstack(['thresh_params', 'max_strain_samples'])
        # Subtract the baseline column from every column, row by row. The
        # previous `.apply(lambda y: y - x[...])` referenced an undefined name
        # `x` and raised NameError on a fresh kernel.
        .pipe(lambda x: x.sub(x[baseline], axis="index"))
        .stack(['thresh_params', 'max_strain_samples'])
        .rename("delta")
        .xs("corr350-depth250", level="thresh_params")
    )
    sns.swarmplot(data=d0.reset_index(), x='max_strain_samples', y='delta', order=max_strain_samples_order, ax=ax)
    ax.axhline(0, color='k', lw=1, linestyle='--')
    ax.set_title(score)

    # Wilcoxon signed-rank result, mean and median of the deltas for each
    # max_strain_samples value.
    d1 = d0.unstack("max_strain_samples")
    for max_strain_samples in max_strain_samples_order:
        print(score, sp.stats.wilcoxon(d1[max_strain_samples].dropna()), d1[max_strain_samples].mean(), d1[max_strain_samples].median())
fig.tight_layout()

In [None]:
# Change in precision / recall / F1 for the "corr350-depth250" threshold at
# every max_strain_samples, measured against a fixed baseline column: the
# "corr0-depth250" result at max_strain_samples == 1.
baseline = ("corr0-depth250", 1)

fig, axs = plt.subplots(3, figsize=(25, 15))


for score, ax in zip(["precision", "recall" , "f1"], axs.flatten()):
    d0 = (
        xjin_benchmarking
        [lambda x: ~x.to_drop & (x.specgene_params == "ref-t25-p95")]
        .set_index(['species', 'genome_id', 'thresh_params', 'max_strain_samples'])
        [score]
        .unstack(['thresh_params', 'max_strain_samples'])
        # Subtract the baseline column from every column, row by row. The
        # previous `.apply(lambda y: y - x[...])` referenced an undefined name
        # `x` and raised NameError on a fresh kernel.
        .pipe(lambda x: x.sub(x[baseline], axis="index"))
        .stack(['thresh_params', 'max_strain_samples'])
        .rename("delta")
        .xs("corr350-depth250", level="thresh_params")
    )
    sns.swarmplot(data=d0.reset_index(), x='max_strain_samples', y='delta', order=max_strain_samples_order, ax=ax)
    ax.axhline(0, color='k', lw=1, linestyle='--')
    ax.set_title(score)

    # Wilcoxon signed-rank result, mean and median of the deltas for each
    # max_strain_samples value.
    d1 = d0.unstack("max_strain_samples")
    for max_strain_samples in max_strain_samples_order:
        print(score, sp.stats.wilcoxon(d1[max_strain_samples].dropna()), d1[max_strain_samples].mean(), d1[max_strain_samples].median())
fig.tight_layout()

In [None]:
# The 50 highest-F1 rows for the "corr350-depth250" threshold using all
# samples (max_strain_samples == -1).
d = xjin_benchmarking.loc[
    lambda x: (
        ~x.to_drop
        & x.specgene_params.eq("ref-t25-p95")
        & x.max_strain_samples.eq(-1)
        & x.thresh_params.eq("corr350-depth250")
    )
]

d.sort_values("f1", ascending=False).head(50)

In [None]:
# Three highest-F1 rows for species 102395, transposed so that each column is
# one benchmarking row.
# The species column is compared as text: it is populated from `species_id`,
# which is read without an explicit dtype and may therefore be integer-typed,
# in which case a direct comparison to the string "102395" matches nothing.
(
    xjin_benchmarking
    [lambda x: x.species.astype(str) == "102395"]
    .sort_values('f1', ascending=False)
    .head(3)
    .T
)

In [None]:
# Pair each genome's single-sample result under "corr0-depth250" (plain
# columns) with its result under "corr350-depth250" ("_alt" columns) and
# scatter one against the other, colored by strain_depth_sum on a log scale.
_single_sample = xjin_benchmarking[
    lambda x: ~x.to_drop
    & (x.specgene_params == "ref-t25-p95")
    & (x.max_strain_samples == 1)
]
d = pd.merge(
    _single_sample[_single_sample.thresh_params == "corr0-depth250"],
    _single_sample[_single_sample.thresh_params == "corr350-depth250"],
    on=["genome_id"],
    suffixes=("", "_alt"),
)

fig, axs = plt.subplots(3, figsize=(6, 15))

for score, ax in zip(['precision', 'recall', 'f1'], axs.flatten()):
    # Reserve a narrow axes to the right of the panel for the colorbar.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    artist = ax.scatter(
        score,
        score + '_alt',
        c="strain_depth_sum",
        # Draw points in order of increasing strain_depth_sum.
        data=d.sort_values("strain_depth_sum"),
        norm=mpl.colors.LogNorm(),
    )
    cbar = fig.colorbar(artist, cax=cax)
    # Identity line for reference.
    ax.plot([0.01, 0.99], [0.01, 0.99])
    ax.set_title(score)
    ax.set_aspect(1)
fig.tight_layout()