In [None]:
# Load the watermark IPython extension; the %watermark magic below uses it to
# report date/machine/version information for reproducibility.
%load_ext watermark


In [None]:
from IPython.display import display, HTML
from hstrat import _auxiliary_lib as hstrat_aux
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from scipy import stats as scipy_stats
from slugify import slugify
from teeplot import teeplot as tp
from tqdm import tqdm

from pylib._calc_normed_defmut_clade_stats import (
    calc_normed_defmut_clade_stats,
)
from pylib._mask_sequence_diffs import mask_sequence_diffs
from pylib._screen_mutation_defined_nodes_sequence_diff import (
    screen_mutation_defined_nodes_sequence_diff,
)
from pylib._seed_global_rngs import seed_global_rngs


In [None]:
# Record date/time, machine, and Python plus imported-package versions so
# the rendered notebook documents the environment that produced it.
%watermark -diwmuv -iv


In [None]:
# Subdirectory passed as `teeplot_subdir` to every tp.teed call below;
# presumably teeplot saves this notebook's figures under it — confirm.
teeplot_subdir = "2025-05-16-vanilla-comparator"
teeplot_subdir


In [None]:
# Seed global RNGs for reproducibility. NOTE(review): presumably this seeds
# NumPy's legacy global state, which the strip-plot jitter
# (np.random.uniform) and the bootstrap (np.random.choice) below draw from —
# confirm against pylib._seed_global_rngs.
seed_global_rngs(1)


## Get Data


In [None]:
# Download the phylogeny table (parquet hosted on OSF), then coerce the
# `origin_time` column to float. NOTE(review): the reason for the cast is
# not visible here — presumably downstream time arithmetic; confirm.
df = pd.read_parquet("https://osf.io/f4qaj/download")
df = df.astype({"origin_time": float})

df.head()


In [None]:
def stripboxen_plot(
    data: pd.DataFrame,
    x: str,
    y: str,
    hue: str,
    y_jitter: float = 1.0,
) -> plt.Axes:
    """Draw a boxen plot overlaid with jittered raw observations.

    Three seaborn layers are drawn on a single axes:

    1. ``sns.boxenplot`` of ``y`` per ``x`` category, colored by ``hue``;
    2. ``sns.barplot`` with fully transparent bars (``alpha=0.0``);
    3. ``sns.stripplot`` of the individual values in black.

    Parameters
    ----------
    data : pd.DataFrame
        Long-form data containing the ``x``, ``y`` and ``hue`` columns.
    x, y, hue : str
        Column names for the categorical axis, the value axis and the color
        grouping.
    y_jitter : float, default 1.0
        Half-width of the uniform noise added to ``y`` for the strip-plot
        points only; the boxen and bar layers use the raw values. The noise
        is drawn from NumPy's global RNG, so calling this advances the global
        random state.

    Returns
    -------
    plt.Axes
        The axes all layers were drawn on.
    """
    ax = sns.boxenplot(
        data=data,
        y=y,
        x=x,
        hue=hue,
        legend=False,
    )
    # Bars are fully transparent; NOTE(review): presumably this layer is kept
    # for its error-bar overlay (seaborn's default mean interval) — confirm.
    sns.barplot(
        data=data,
        y=y,
        x=x,
        hue=hue,
        alpha=0.0,
        ax=ax,
        legend=False,
    )
    # Vertical noise separates overlapping points (e.g. tied values); the
    # horizontal spread within each category comes from `jitter` below.
    y_noise = np.random.uniform(-y_jitter, y_jitter, len(data))
    sns.stripplot(
        y=data[y] + y_noise,
        x=data[x],
        alpha=0.2,
        ax=ax,
        color="k",
        legend=False,
        jitter=0.3,
        size=4,
    )
    return ax


In [None]:
# Shared settings for the per-treatment comparison. OT_DELTAS and MATCH_COLS
# feed both calc_normed_defmut_clade_stats and the list of plotted columns,
# so the two cannot drift apart.
OT_DELTAS = (4, 7, 14, 28, 44)
MATCH_COLS = ("variant_flavor",)
FOCAL_HUE = "is_focal_mutant"
THRESHOLD = 50  # reference line and null value for the threshold tests
N_BOOT = 10_000  # bootstrap resamples for the mean-vs-threshold test

# Normalized clade-stat columns to plot: for each base statistic, the
# "all" normalization, then each ot_delta, then each match column.
Y_COLS = tuple(
    f"defmut_norm_{norm}-{stat}"
    for stat in ("num_leaves", "clade_duration")
    for norm in (
        "all",
        *(f"ot_delta:{d}" for d in OT_DELTAS),
        *(f"match:{c}" for c in MATCH_COLS),
    )
)


def annotate_threshold_tests(
    ax: plt.Axes,
    data: pd.DataFrame,
    y: str,
    hue: str = FOCAL_HUE,
    threshold: float = THRESHOLD,
    n_boot: int = N_BOOT,
) -> None:
    """Annotate each ``hue`` category on ``ax`` with threshold-test results.

    Categories are visited in order of first appearance in ``data[hue]``,
    and category ``j`` is labeled at ``(x=j, y=5)`` in data coordinates.
    NOTE(review): this assumes seaborn placed the categories on the x axis
    in the same first-appearance order — confirm.

    Each label reports:

    - ``n``: number of non-NaN values in the category not equal to
      ``threshold``. These non-tied values are the sign-test sample.
    - ``binom<T``: one-sided binomial test p-value that the fraction of
      non-tied values below ``threshold`` exceeds the same fraction in the
      ``"null"`` category. It is NaN when either sample is empty.
    - ``mean<T``: fraction of bootstrap resample means that are at or above
      ``threshold``. Resampling uses NumPy's global RNG.

    Categories with no non-NaN values are skipped.
    """
    # Reference rate: share of non-tied, non-NaN "null" values below threshold.
    null_vals = data.loc[data[hue] == "null", y].dropna()
    null_vals = null_vals[null_vals != threshold]
    null_p = (null_vals < threshold).mean()  # NaN if null_vals is empty

    for j, h_cat in enumerate(data[hue].unique().tolist()):
        grp = data.loc[data[hue] == h_cat, y].dropna()
        if grp.empty:
            continue

        # --- bootstrap test for mean < threshold ---
        boot_means = np.random.choice(
            grp, size=(n_boot, len(grp)), replace=True
        ).mean(axis=1)
        p_boot = np.mean(boot_means >= threshold)

        # --- binomial/sign test for median < threshold ---
        # Ties at the threshold are excluded. The previous expression
        # `grp != threshold & ~np.isnan(grp)` parsed as
        # `grp != (threshold & ...)`, i.e. `grp != 0`, which could make k > n.
        nontie = grp[grp != threshold]
        k = int((nontie < threshold).sum())
        n = len(nontie)
        if n > 0 and np.isfinite(null_p):
            p_binom = scipy_stats.binomtest(
                k, n, p=null_p, alternative="greater"
            ).pvalue
        else:
            p_binom = np.nan

        text = (
            f"n={n}\n"
            f"binom<{threshold}: {p_binom:.3f}\n"
            f"mean<{threshold}: {p_boot:.3f}"
        )
        ax.text(j, 5, text, ha="center", va="bottom", fontsize="small")


for trt_name, group in df.groupby("trt_name", observed=True):
    # Analyze only the first replicate (by order of appearance) per treatment.
    first_replicate = group["replicate_uuid"].unique()[0]
    dfx = hstrat_aux.alifestd_to_working_format(
        group[group["replicate_uuid"] == first_replicate]
    ).reset_index(drop=True)

    # Items are ((mut_char_pos, mut_char_ref, mut_char_var), mut_mask).
    mutations = [
        *mask_sequence_diffs(
            ancestral_sequence=dfx["ancestral_sequence"]
            .astype(str)
            .unique()
            .item(),
            sequence_diffs=dfx["sequence_diff"],
            sparsify_mask=False,
        )
    ]

    # Node mask of mutation-defined clades, keyed by (pos, ref, var).
    defining_masks = {
        (
            mut_char_pos,
            mut_char_ref,
            mut_char_var,
        ): screen_mutation_defined_nodes_sequence_diff(
            phylo_df=dfx,
            mut_char_pos=mut_char_pos,
            mut_char_var=mut_char_var,
        )
        for (mut_char_pos, mut_char_ref, mut_char_var), _mut_mask in mutations
    }
    # The first mutation (dict insertion order) is designated the focal one.
    dfx[FOCAL_HUE] = next(iter(defining_masks.values()))

    dfx = calc_normed_defmut_clade_stats(
        phylo_df=dfx,
        defmut_clade_masks=defining_masks,
        match_cols=list(MATCH_COLS),
        ot_deltas=OT_DELTAS,
        progress_wrap=tqdm,
    )

    # Append a copy of every row relabeled "null" as a pooled reference group.
    null_rows = dfx.copy()
    null_rows[FOCAL_HUE] = "null"
    data = pd.concat([dfx, null_rows], ignore_index=True)

    for y in Y_COLS:
        display(HTML(f"<h2>{trt_name} {y}</h2>"))
        with tp.teed(
            stripboxen_plot,
            data=data,
            x=FOCAL_HUE,
            y=y,
            hue=FOCAL_HUE,
            teeplot_outattrs={
                "trt_name": slugify(trt_name),
            },
            teeplot_subdir=teeplot_subdir,
        ) as ax:
            ax.axhline(THRESHOLD, color="k", linestyle="--")
            annotate_threshold_tests(
                ax,
                data,
                y,
                hue=FOCAL_HUE,
                threshold=THRESHOLD,
                n_boot=N_BOOT,
            )
